diff --git a/.gitignore b/.gitignore index d11a504bdc56ee98b3d5a0c33f9f75d996e45567..be75938ec401b1d72fa54773c85191aaac7d7f35 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ node_modules /bazel-* /bazel_pip /tools/python_bin_path.sh -/tools/git/gen +/tensorflow/tools/git/gen /pip_test /_python_build *.pyc @@ -26,4 +26,11 @@ Podfile.lock /tensorflow/contrib/lite/gen/** /tensorflow/contrib/lite/examples/ios/simple/data/*.txt /tensorflow/contrib/lite/examples/ios/simple/data/*.tflite -xcuserdata/** \ No newline at end of file +xcuserdata/** + +# Android +.gradle +.idea +*.iml +local.properties +gradleBuild diff --git a/BUILD b/BUILD index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..4bf647e47aa56cff0b3fd5af7d5df99d8b70549b 100644 --- a/BUILD +++ b/BUILD @@ -0,0 +1,6 @@ +exports_files( + [ + "LICENSE", + "ACKNOWLEDGEMENTS", + ], +) diff --git a/CODEOWNERS b/CODEOWNERS index 57a4df40e651f45dc03493af631d73332e46c182..007a304c3e706ce968576ec8979c08f1a3bcc552 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,53 +1,53 @@ # NOTE: Disabled temporarily because it's too noisy on pushes. # Where component owners are known, add them here. -#tensorflow/core/platform/windows/* @mrry -#tensorflow/java/* @asimshankar -#tensorflow/tensorboard/* @jart @dandelionmane -#tensorflow/tools/docs/* @markdaoust +# /tensorflow/core/platform/windows/ @mrry +# /tensorflow/java/ @asimshankar +# /tensorflow/tensorboard/ @jart @dandelionmane +# /tensorflow/tools/docs/ @markdaoust # contrib -# NEED OWNER: tensorflow/contrib/avro/* -#tensorflow/contrib/batching/* @alextp @chrisolston -#tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon -#tensorflow/contrib/boosted_trees/* @sshrdp @yk5 @nataliaponomareva -#tensorflow/contrib/cmake/* @mrry @benoitsteiner -#tensorflow/contrib/copy_graph/* @tucker @poxvoculi -#tensorflow/contrib/crf/* @kentonl -#tensorflow/contrib/data/* @mrry -#tensorflow/contrib/distributions/* @jvdillon @langmore @rsepassi -#tensorflow/contrib/factorization/* @agarwal-ashish @xavigonzalvo -#tensorflow/contrib/ffmpeg/* @fredbertsch -# NEED OWNER: tensorflow/contrib/framework/* -#tensorflow/contrib/graph_editor/* @purpledog -# NEED OWNER: tensorflow/contrib/grid_rnn/* -#tensorflow/contrib/hvx/* @satok16 -#tensorflow/contrib/integrate/* @shoyer -#tensorflow/contrib/kernel_methods/* @petrosmol -#tensorflow/contrib/ios_examples/* @petewarden -#tensorflow/contrib/labeled_tensor/* @shoyer -#tensorflow/contrib/layers/* @fchollet @martinwicke -#tensorflow/contrib/learn/* @martinwicke @ispirmustafa @alextp -#tensorflow/contrib/linalg/* @langmore -#tensorflow/contrib/linear_optimizer/* @petrosmol @andreasst @katsiapis -#tensorflow/contrib/lookup/* @ysuematsu @andreasst -#tensorflow/contrib/losses/* @alextp @ispirmustafa -#tensorflow/contrib/makefile/* @petewarden @satok16 @wolffg -#tensorflow/contrib/metrics/* @alextp @honkentuber @ispirmustafa -#tensorflow/contrib/nccl/* @cwhipkey @zheng-xq -#tensorflow/contrib/opt/* @strategist333 -#tensorflow/contrib/pi_examples/* @maciekcc -#tensorflow/contrib/quantization/* @petewarden @cwhipkey @keveman -#tensorflow/contrib/rnn/* @ebrevdo -#tensorflow/contrib/saved_model/* @nfiedel @sukritiramesh -#tensorflow/contrib/seq2seq/* @lukaszkaiser -#tensorflow/contrib/session_bundle/* @nfiedel @sukritiramesh -#tensorflow/contrib/slim/* @sguada @thenbasilmanran -#tensorflow/contrib/stateless/* @girving -#tensorflow/contrib/tensor_forest/* @gilberthendry @thomascolthurst -#tensorflow/contrib/testing/* @dandelionmane -#tensorflow/contrib/timeseries/* @allenlavoie -#tensorflow/contrib/tpu/* @frankchn @saeta @jhseu -#tensorflow/contrib/training/* @joel-shor @ebrevdo -#tensorflow/contrib/util/* @sherrym +# NEED OWNER: /tensorflow/contrib/avro/ +# /tensorflow/contrib/batching/ @alextp @chrisolston +# /tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon +# /tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva +# /tensorflow/contrib/cmake/ @mrry @benoitsteiner +# /tensorflow/contrib/copy_graph/ @tucker @poxvoculi +# /tensorflow/contrib/crf/ @kentonl +# /tensorflow/contrib/data/ @mrry +# /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi +# /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo +# /tensorflow/contrib/ffmpeg/ @fredbertsch +# NEED OWNER: /tensorflow/contrib/framework/ +# /tensorflow/contrib/graph_editor/ @purpledog +# NEED OWNER: /tensorflow/contrib/grid_rnn/ +# /tensorflow/contrib/hvx/ @satok16 +# /tensorflow/contrib/integrate/ @shoyer +# /tensorflow/contrib/kernel_methods/ @petrosmol +# /tensorflow/contrib/ios_examples/ @petewarden +# /tensorflow/contrib/labeled_tensor/ @shoyer +# /tensorflow/contrib/layers/ @fchollet @martinwicke +# /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp +# /tensorflow/contrib/linalg/ @langmore +# /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis +# /tensorflow/contrib/lookup/ @ysuematsu @andreasst +# /tensorflow/contrib/losses/ @alextp @ispirmustafa +# /tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg +# /tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa +# /tensorflow/contrib/nccl/ @cwhipkey @zheng-xq +# /tensorflow/contrib/opt/ @strategist333 +# /tensorflow/contrib/pi_examples/ @maciekcc +# /tensorflow/contrib/quantization/ @petewarden @cwhipkey @keveman +# /tensorflow/contrib/rnn/ @ebrevdo +# /tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh +# /tensorflow/contrib/seq2seq/ @lukaszkaiser +# /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh +# /tensorflow/contrib/slim/ @sguada @thenbasilmanran +# /tensorflow/contrib/stateless/ @girving +# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst +# /tensorflow/contrib/testing/ @dandelionmane +# /tensorflow/contrib/timeseries/ @allenlavoie +# /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu +# /tensorflow/contrib/training/ @joel-shor @ebrevdo +# /tensorflow/contrib/util/ @sherrym diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index ff11d131409b65880f16b80f9fe38dc39ac0e5fa..5fff9d05a1c589636bc9c711e6eb7cc4aba86b2f 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -67,4 +67,4 @@ If the Project Stewards receive a report alleging a violation of the Code of Con ## Attribution -This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at http://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct. +This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1b537ca73cc94e992e7537fe69c8d0cc8fd13102..de4fded6ae6e66995aa9f1687a9d598017416f7a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,8 +8,8 @@ We'd love to accept your patches! Before we can take them, we have to jump a cou Please fill out either the individual or corporate Contributor License Agreement (CLA). - * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html). - * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html). + * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://code.google.com/legal/individual-cla-v1.0.html). + * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://code.google.com/legal/corporate-cla-v1.0.html). Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests. @@ -20,6 +20,9 @@ Follow either of the two links above to access the appropriate CLA and instructi If you have improvements to TensorFlow, send us your pull requests! For those just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/). +TensorFlow team members will be assigned to review your pull requests. Once the pull requests are approved and pass continuous integration checks, we will merge the pull requests. +For some pull requests, we will apply the patch for each pull request to our internal version control system first, and export the change out as a new commit later, at which point the original pull request will be closed. The commits in the pull request will be squashed into a single commit with the pull request creator as the author. These pull requests will be labeled as pending merge internally. + If you want to contribute but you're not sure where to start, take a look at the [issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome). These are issues that we believe are particularly well suited for outside @@ -114,7 +117,7 @@ pylint --rcfile=/tmp/pylintrc myfile.py * [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) * [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) * [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) -* [Google Objective-C Style Guide](http://google.github.io/styleguide/objcguide.html) +* [Google Objective-C Style Guide](https://google.github.io/styleguide/objcguide.html) #### Running sanity check diff --git a/LICENSE b/LICENSE index 15ae42140452d32ccf929f59f7eca01a3c7b555f..4862420c0234f7542d4fe8f3520516b484a64aed 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2017 The TensorFlow Authors. All rights reserved. +Copyright 2018 The TensorFlow Authors. All rights reserved. Apache License Version 2.0, January 2004 diff --git a/README.md b/README.md index aff3427bddb307aea6d6c2466eac14c9edffcc32..0c93813e584d4e41fe80d50e047069b2dad8311a 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,10 @@ tracking requests and bugs. So please see [TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** +The TensorFlow project strives to abide by generally accepted best practices in open-source software development: + +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + ## Installation *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* @@ -46,11 +50,11 @@ packages on Linux, Mac, and Windows. **Individual whl files** -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) +* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/)) +* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/)) * Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) -* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) +* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) +* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) diff --git a/RELEASE.md b/RELEASE.md index e04bd3fc505d51ade9e9fa12c822cb695e90b4f3..39fc46ac6357300ea2b3365fa4c6d432d2a206db 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,175 @@ +# Release 1.5.0 + +## Breaking Changes +* Prebuilt binaries are now built against CUDA 9 and cuDNN 7. +* Our Linux binaries are built using ubuntu 16 containers, potentially + introducing glibc incompatibility issues with ubuntu 14. +* Starting from 1.6 release, our prebuilt binaries will use AVX instructions. + This may break TF on older CPUs. + +## Major Features And Improvements +* [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager) + preview version is now available. +* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite) + dev preview is now available. +* CUDA 9 and cuDNN 7 support. +* Accelerated Linear Algebra (XLA): + * Add `complex64` support to XLA compiler. + * `bfloat` support is now added to XLA infrastructure. + * Make `ClusterSpec` propagation work with XLA devices. + * Use a determinisitic executor to generate XLA graph. +* `tf.contrib`: + * `tf.contrib.distributions`: + * Add `tf.contrib.distributions.Autoregressive`. + * Make `tf.contrib.distributions` QuadratureCompound classes support batch + * Infer `tf.contrib.distributions.RelaxedOneHotCategorical` `dtype` from arguments. + * Make `tf.contrib.distributions` quadrature family parameterized by + `quadrature_grid_and_prob` vs `quadrature_degree`. + * `auto_correlation` added to `tf.contrib.distributions` + * Add `tf.contrib.bayesflow.layers`, a collection of probabilistic (neural) layers. + * Add `tf.contrib.bayesflow.halton_sequence`. + * Add `tf.contrib.data.make_saveable_from_iterator.` + * Add `tf.contrib.data.shuffle_and_repeat`. + * Add new custom transformation: `tf.contrib.data.scan()`. + * `tf.contrib.distributions.bijectors`: + * Add `tf.contrib.distributions.bijectors.MaskedAutoregressiveFlow`. + * Add `tf.contrib.distributions.bijectors.Permute`. + * Add `tf.contrib.distributions.bijectors.Gumbel`. + * Add `tf.contrib.distributions.bijectors.Reshape`. + * Support shape inference (i.e., shapes containing -1) in the Reshape bijector. +* Add `streaming_precision_recall_at_equal_thresholds,` a method for computing + streaming precision and recall with `O(num_thresholds + size of predictions)` + time and space complexity. +* Change `RunConfig` default behavior to not set a random seed, making random + behavior independently random on distributed workers. We expect this to + generally improve training performance. Models that do rely on determinism + should set a random seed explicitly. +* Replaced the implementation of `tf.flags` with `absl.flags`. +* Add support for `CUBLAS_TENSOR_OP_MATH` in fp16 GEMM +* Add support for CUDA on NVIDIA Tegra devices + +## Bug Fixes and Other Changes +* Documentation updates: + * Clarified that you can only install TensorFlow on 64-bit machines. + * Added a short doc explaining how `Estimator`s save checkpoints. + * Add documentation for ops supported by the `tf2xla` bridge. + * Fix minor typos in the doc of `SpaceToDepth` and `DepthToSpace`. + * Updated documentation comments in `mfcc_mel_filterbank.h` and `mfcc.h` to + clarify that the input domain is squared magnitude spectra and the weighting + is done on linear magnitude spectra (sqrt of inputs). + * Change `tf.contrib.distributions` docstring examples to use `tfd` alias + rather than `ds`, `bs`. + * Fix docstring typos in `tf.distributions.bijectors.Bijector`. + * `tf.assert_equal` no longer raises `ValueError.` It now raises + `InvalidArgumentError,` as documented. + * Update Getting Started docs and API intro. +* Google Cloud Storage (GCS): + * Add userspace DNS caching for the GCS client. + * Customize request timeouts for the GCS filesystem. + * Improve GCS filesystem caching. +* Bug Fixes: + * Fix bug where partitioned integer variables got their wrong shapes. Before + * Fix correctness bug in CPU and GPU implementations of Adadelta. + * Fix a bug in `import_meta_graph`'s handling of partitioned variables when + importing into a scope. WARNING: This may break loading checkpoints of + graphs with partitioned variables saved after using `import_meta_graph` with + a non-empty `import_scope` argument. + * Fix bug in offline debugger which prevented viewing events. + * Added the `WorkerService.DeleteWorkerSession` method to the gRPC interface, + to fix a memory leak. Ensure that your master and worker servers are running + the same version of TensorFlow to avoid compatibility issues. + * Fix bug in peephole implementation of BlockLSTM cell. + * Fix bug by casting dtype of `log_det_jacobian` to match `log_prob` in + `TransformedDistribution`. + * Fix a bug in `import_meta_graph`'s handling of partitioned variables when + * Ensure `tf.distributions.Multinomial` doesn't underflow in `log_prob`. + Before this change, all partitions of an integer variable were initialized + with the shape of the unpartitioned variable; after this change they are + initialized correctly. +* Other: + * Add necessary shape util support for bfloat16. + * Add a way to run ops using a step function to MonitoredSession. + * Add `DenseFlipout` probabilistic layer. + * A new flag `ignore_live_threads` is available on train. If set to `True`, it + will ignore threads that remain running when tearing down infrastructure + after successfully completing training, instead of throwing a RuntimeError. + * Restandardize `DenseVariational` as simpler template for other probabilistic + layers. + * `tf.data` now supports `tf.SparseTensor` components in dataset elements. + * It is now possible to iterate over `Tensor`s. + * Allow `SparseSegmentReduction` ops to have missing segment IDs. + * Modify custom export strategy to account for multidimensional sparse float + splits. + * `Conv2D`, `Conv2DBackpropInput`, `Conv2DBackpropFilter` now supports arbitrary + dilations with GPU and cuDNNv6 support. + * `Estimator` now supports `Dataset`: `input_fn` can return a `Dataset` + instead of `Tensor`s. + * Add `RevBlock`, a memory-efficient implementation of reversible residual layers. + * Reduce BFCAllocator internal fragmentation. + * Add `cross_entropy` and `kl_divergence` to `tf.distributions.Distribution`. + * Add `tf.nn.softmax_cross_entropy_with_logits_v2` which enables backprop + w.r.t. the labels. + * GPU back-end now uses `ptxas` to compile generated PTX. + * `BufferAssignment`'s protocol buffer dump is now deterministic. + * Change embedding op to use parallel version of `DynamicStitch`. + * Add support for sparse multidimensional feature columns. + * Speed up the case for sparse float columns that have only 1 value. + * Allow sparse float splits to support multivalent feature columns. + * Add `quantile` to `tf.distributions.TransformedDistribution`. + * Add `NCHW_VECT_C` support for `tf.depth_to_space` on GPU. + * Add `NCHW_VECT_C` support for `tf.space_to_depth` on GPU. + +## API Changes +* Rename `SqueezeDims` attribute to `Axis` in C++ API for Squeeze op. +* `Stream::BlockHostUntilDone` now returns Status rather than bool. +* Minor refactor: move stats files from `stochastic` to `common` and remove + `stochastic`. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Adam Zahran, Ag Ramesh, Alan Lee, Alan Yee, Alex Sergeev, Alexander, Amir H. Jadidinejad, +Amy, Anastasios Doumoulakis, Andrei Costinescu, Andrei Nigmatulin, Anthony Platanios, +Anush Elangovan, arixlin, Armen Donigian, ArtëM Sobolev, Atlas7, Ben Barsdell, Bill Prin, +Bo Wang, Brett Koonce, Cameron Thomas, Carl Thomé, Cem Eteke, cglewis, Changming Sun, +Charles Shenton, Chi-Hung, Chris Donahue, Chris Filo Gorgolewski, Chris Hoyean Song, +Chris Tava, Christian Grail, Christoph Boeddeker, cinqS, Clayne Robison, codrut3, concerttttt, +CQY, Dan Becker, Dan Jarvis, Daniel Zhang, David Norman, dmaclach, Dmitry Trifonov, +Donggeon Lim, dongpilYu, Dr. Kashif Rasul, Edd Wilder-James, Eric Lv, fcharras, Felix Abecassis, +FirefoxMetzger, formath, FredZhang, Gaojin Cao, Gary Deer, Guenther Schmuelling, Hanchen Li, +Hanmin Qin, hannesa2, hyunyoung2, Ilya Edrenkin, Jackson Kontny, Jan, Javier Luraschi, +Jay Young, Jayaram Bobba, Jeff, Jeff Carpenter, Jeremy Sharpe, Jeroen BéDorf, Jimmy Jia, +Jinze Bai, Jiongyan Zhang, Joe Castagneri, Johan Ju, Josh Varty, Julian Niedermeier, +JxKing, Karl Lessard, Kb Sriram, Keven Wang, Koan-Sin Tan, Kyle Mills, lanhin, LevineHuang, +Loki Der Quaeler, Loo Rong Jie, Luke Iwanski, LáSzló Csomor, Mahdi Abavisani, Mahmoud Abuzaina, +ManHyuk, Marek ŠUppa, MathSquared, Mats Linander, Matt Wytock, Matthew Daley, Maximilian Bachl, +mdymczyk, melvyniandrag, Michael Case, Mike Traynor, miqlas, Namrata-Ibm, Nathan Luehr, +Nathan Van Doorn, Noa Ezra, Nolan Liu, Oleg Zabluda, opensourcemattress, Ouwen Huang, +Paul Van Eck, peisong, Peng Yu, PinkySan, pks, powderluv, Qiao Hai-Jun, Qiao Longfei, +Rajendra Arora, Ralph Tang, resec, Robin Richtsfeld, Rohan Varma, Ryohei Kuroki, SaintNazaire, +Samuel He, Sandeep Dcunha, sandipmgiri, Sang Han, scott, Scott Mudge, Se-Won Kim, Simon Perkins, +Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, Taehoon Lee, Ted Chang, Thomas Deegan, +Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay, +Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang, +Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武 + +We are also grateful to all who filed issues or helped resolve them, asked and +answered questions, and were part of inspiring discussions. + +# Release 1.4.1 + +## Bug Fixes and Other Changes +* `LinearClassifier` fix for CloudML Engine. + +# Release 1.4.0 + +## Major Features And Improvements +* `tf.keras` is now part of the core TensorFlow API. +* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of + the core TensorFlow API. + * The API is now subject to backwards compatibility guarantees. + # Release 1.4.0 ## Major Features And Improvements diff --git a/WORKSPACE b/WORKSPACE index b40913801ba8e3c8ee73f7ba69540b520ad698a6..7ae39374f18efd3bddb9aae9bb8dba5c13a61dcc 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257", - strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1", + sha256 = "6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657", + strip_prefix = "rules_closure-08039ba8ca59f64248bb3b6ae016460fe9c9914f", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", # 2017-08-28 + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", # 2018-01-16 ], ) diff --git a/configure.py b/configure.py index cf562bdee8ef288e4c2938f50e5c6366ce05ccff..cf16ef483763733cc12c838ea92b144c6493f0b1 100644 --- a/configure.py +++ b/configure.py @@ -34,8 +34,10 @@ except ImportError: _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.tf_configure.bazelrc') -_DEFAULT_CUDA_VERSION = '8.0' -_DEFAULT_CUDNN_VERSION = '6' +_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'WORKSPACE') +_DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' @@ -44,6 +46,13 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] + +_DEFAULT_PROMPT_ASK_ATTEMPTS = 10 + + +class UserInputError(Exception): + pass def is_windows(): @@ -158,7 +167,7 @@ def get_python_path(environ_cp, python_bin_path): try: library_paths = run_shell( [python_bin_path, '-c', - 'import site; print("\\n".join(site.getsitepackages()))']).split("\n") + 'import site; print("\\n".join(site.getsitepackages()))']).split('\n') except subprocess.CalledProcessError: library_paths = [run_shell( [python_bin_path, '-c', @@ -256,19 +265,6 @@ def reset_tf_configure_bazelrc(): f.write('import %workspace%/.tf_configure.bazelrc\n') -def run_gen_git_source(environ_cp): - """Run the gen_git_source to create links. - - The links are for bazel to track dependencies for git hash propagation. - - Args: - environ_cp: copy of the os.environ. - """ - cmd = '"%s" tensorflow/tools/git/gen_git_source.py --configure %s' % ( - environ_cp.get('PYTHON_BIN_PATH'), os.getcwd()) - os.system(cmd) - - def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -306,6 +302,12 @@ def get_var(environ_cp, Returns: boolean value of the variable. + + Raises: + UserInputError: if an environment variable is set, but it cannot be + interpreted as a boolean indicator, assume that the user has made a + scripting error, and will continue to provide invalid input. + Raise the error to avoid infinitely looping. """ if not question: question = 'Do you wish to build TensorFlow with %s support?' % query_item @@ -323,6 +325,23 @@ def get_var(environ_cp, question += ' [y/N]: ' var = environ_cp.get(var_name) + if var is not None: + var_content = var.strip().lower() + true_strings = ('1', 't', 'true', 'y', 'yes') + false_strings = ('0', 'f', 'false', 'n', 'no') + if var_content in true_strings: + var = True + elif var_content in false_strings: + var = False + else: + raise UserInputError( + 'Environment variable %s must be set as a boolean indicator.\n' + 'The following are accepted as TRUE : %s.\n' + 'The following are accepted as FALSE: %s.\n' + 'Current value is %s.' % ( + var_name, ', '.join(true_strings), ', '.join(false_strings), + var)) + while var is None: user_input_origin = get_input(question) user_input = user_input_origin.strip().lower() @@ -509,6 +528,21 @@ def set_tf_cuda_clang(environ_cp): no_reply=no_reply) +def set_tf_download_clang(environ_cp): + """Set TF_DOWNLOAD_CLANG action_env.""" + question = 'Do you want to download a fresh release of clang? (Experimental)' + yes_reply = 'Clang will be downloaded and used to compile tensorflow.' + no_reply = 'Clang will not be downloaded.' + set_action_env_var( + environ_cp, + 'TF_DOWNLOAD_CLANG', + None, + False, + question=question, + yes_reply=yes_reply, + no_reply=no_reply) + + def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, var_default): """Get var_name either from env, or user or default. @@ -557,6 +591,219 @@ def set_clang_cuda_compiler_path(environ_cp): clang_cuda_compiler_path) +def prompt_loop_or_load_from_env( + environ_cp, + var_name, + var_default, + ask_for_var, + check_success, + error_msg, + suppress_default_error=False, + n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS +): + """Loop over user prompts for an ENV param until receiving a valid response. + + For the env param var_name, read from the environment or verify user input + until receiving valid input. When done, set var_name in the environ_cp to its + new value. + + Args: + environ_cp: (Dict) copy of the os.environ. + var_name: (String) string for name of environment variable, e.g. "TF_MYVAR". + var_default: (String) default value string. + ask_for_var: (String) string for how to ask for user input. + check_success: (Function) function that takes one argument and returns a + boolean. Should return True if the value provided is considered valid. May + contain a complex error message if error_msg does not provide enough + information. In that case, set suppress_default_error to True. + error_msg: (String) String with one and only one '%s'. Formatted with each + invalid response upon check_success(input) failure. + suppress_default_error: (Bool) Suppress the above error message in favor of + one from the check_success function. + n_ask_attempts: (Integer) Number of times to query for valid input before + raising an error and quitting. + + Returns: + [String] The value of var_name after querying for input. + + Raises: + UserInputError: if a query has been attempted n_ask_attempts times without + success, assume that the user has made a scripting error, and will + continue to provide invalid input. Raise the error to avoid infinitely + looping. + """ + default = environ_cp.get(var_name) or var_default + full_query = '%s [Default is %s]: ' % ( + ask_for_var, + default, + ) + + for _ in range(n_ask_attempts): + val = get_from_env_or_user_or_default(environ_cp, + var_name, + full_query, + default) + if check_success(val): + break + if not suppress_default_error: + print(error_msg % val) + environ_cp[var_name] = '' + else: + raise UserInputError('Invalid %s setting was provided %d times in a row. ' + 'Assuming to be a scripting mistake.' % + (var_name, n_ask_attempts)) + + environ_cp[var_name] = val + return val + + +def create_android_ndk_rule(environ_cp): + """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % + environ_cp['APPDATA']) + elif is_macos(): + default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + else: + default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + + def valid_ndk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'source.properties'))) + + android_ndk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_NDK_HOME', + var_default=default_ndk_path, + ask_for_var='Please specify the home path of the Android NDK to use.', + check_success=valid_ndk_path, + error_msg=('The path %s or its child file "source.properties" ' + 'does not exist.') + ) + + write_android_ndk_workspace_rule(android_ndk_home_path) + + +def create_android_sdk_rule(environ_cp): + """Set Android variables and write Android SDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA']) + elif is_macos(): + default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + else: + default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME'] + + def valid_sdk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'platforms')) and + os.path.exists(os.path.join(path, 'build-tools'))) + + android_sdk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_SDK_HOME', + var_default=default_sdk_path, + ask_for_var='Please specify the home path of the Android SDK to use.', + check_success=valid_sdk_path, + error_msg=('Either %s does not exist, or it does not contain the ' + 'subdirectories "platforms" and "build-tools".')) + + platforms = os.path.join(android_sdk_home_path, 'platforms') + api_levels = sorted(os.listdir(platforms)) + api_levels = [x.replace('android-', '') for x in api_levels] + + def valid_api_level(api_level): + return os.path.exists(os.path.join(android_sdk_home_path, + 'platforms', + 'android-' + api_level)) + + android_api_level = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_API_LEVEL', + var_default=api_levels[-1], + ask_for_var=('Please specify the Android SDK API level to use. ' + '[Available levels: %s]') % api_levels, + check_success=valid_api_level, + error_msg='Android-%s is not present in the SDK path.') + + build_tools = os.path.join(android_sdk_home_path, 'build-tools') + versions = sorted(os.listdir(build_tools)) + + def valid_build_tools(version): + return os.path.exists(os.path.join(android_sdk_home_path, + 'build-tools', + version)) + + android_build_tools_version = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_BUILD_TOOLS_VERSION', + var_default=versions[-1], + ask_for_var=('Please specify an Android build tools version to use. ' + '[Available versions: %s]') % versions, + check_success=valid_build_tools, + error_msg=('The selected SDK does not have build-tools version %s ' + 'available.')) + + write_android_sdk_workspace_rule(android_sdk_home_path, + android_build_tools_version, + android_api_level) + + +def write_android_sdk_workspace_rule(android_sdk_home_path, + android_build_tools_version, + android_api_level): + print('Writing android_sdk_workspace rule.\n') + with open(_TF_WORKSPACE, 'a') as f: + f.write(""" +android_sdk_repository( + name="androidsdk", + api_level=%s, + path="%s", + build_tools_version="%s")\n +""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) + + +def write_android_ndk_workspace_rule(android_ndk_home_path): + print('Writing android_ndk_workspace rule.') + ndk_api_level = check_ndk_level(android_ndk_home_path) + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + with open(_TF_WORKSPACE, 'a') as f: + f.write(""" +android_ndk_repository( + name="androidndk", + path="%s", + api_level=%s)\n +""" % (android_ndk_home_path, ndk_api_level)) + + +def check_ndk_level(android_ndk_home_path): + """Check the revision number of an Android NDK path.""" + properties_path = '%s/source.properties' % android_ndk_home_path + if is_windows() or is_cygwin(): + properties_path = cygpath(properties_path) + with open(properties_path, 'r') as f: + filedata = f.read() + + revision = re.search(r'Pkg.Revision = (\d+)', filedata) + if revision: + return revision.group(1) + return None + + +def workspace_has_any_android_rule(): + """Check the WORKSPACE for existing android_*_repository rules.""" + with open(_TF_WORKSPACE, 'r') as f: + workspace = f.read() + has_any_rule = re.search(r'^android_[ns]dk_repository', + workspace, + re.MULTILINE) + return has_any_rule + + def set_gcc_host_compiler_path(environ_cp): """Set GCC_HOST_COMPILER_PATH.""" default_gcc_host_compiler_path = which('gcc') or '' @@ -566,23 +813,16 @@ def set_gcc_host_compiler_path(environ_cp): # os.readlink is only available in linux default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink) - ask_gcc_path = ( - 'Please specify which gcc should be used by nvcc as the ' - 'host compiler. [Default is %s]: ') % default_gcc_host_compiler_path - while True: - gcc_host_compiler_path = get_from_env_or_user_or_default( - environ_cp, 'GCC_HOST_COMPILER_PATH', ask_gcc_path, - default_gcc_host_compiler_path) - - if os.path.exists(gcc_host_compiler_path): - break - - # Reset and retry - print('Invalid gcc path. %s cannot be found' % gcc_host_compiler_path) - environ_cp['GCC_HOST_COMPILER_PATH'] = '' + gcc_host_compiler_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='GCC_HOST_COMPILER_PATH', + var_default=default_gcc_host_compiler_path, + ask_for_var= + 'Please specify which gcc should be used by nvcc as the host compiler.', + check_success=os.path.exists, + error_msg='Invalid gcc path. %s cannot be found.', + ) - # Set GCC_HOST_COMPILER_PATH - environ_cp['GCC_HOST_COMPILER_PATH'] = gcc_host_compiler_path write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) @@ -592,7 +832,7 @@ def set_tf_cuda_version(environ_cp): 'Please specify the CUDA SDK version you want to use, ' 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION - while True: + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): # Configure the Cuda SDK version to use. tf_cuda_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION) @@ -630,6 +870,11 @@ def set_tf_cuda_version(environ_cp): environ_cp['TF_CUDA_VERSION'] = '' environ_cp['CUDA_TOOLKIT_PATH'] = '' + else: + raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) + # Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path) @@ -643,7 +888,7 @@ def set_tf_cudnn_version(environ_cp): 'Please specify the cuDNN version you want to use. ' '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION - while True: + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_cudnn_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, _DEFAULT_CUDNN_VERSION) @@ -702,6 +947,10 @@ def set_tf_cudnn_version(environ_cp): print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)) environ_cp['TF_CUDNN_VERSION'] = '' + else: + raise UserInputError('Invalid TF_CUDNN setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) # Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path @@ -810,90 +1059,83 @@ def set_other_cuda_vars(environ_cp): def set_host_cxx_compiler(environ_cp): """Set HOST_CXX_COMPILER.""" default_cxx_host_compiler = which('g++') or '' - ask_cxx_host_compiler = ( - 'Please specify which C++ compiler should be used as' - ' the host C++ compiler. [Default is %s]: ') % default_cxx_host_compiler - while True: - host_cxx_compiler = get_from_env_or_user_or_default( - environ_cp, 'HOST_CXX_COMPILER', ask_cxx_host_compiler, - default_cxx_host_compiler) - if os.path.exists(host_cxx_compiler): - break - - # Reset and retry - print('Invalid C++ compiler path. %s cannot be found' % host_cxx_compiler) - environ_cp['HOST_CXX_COMPILER'] = '' + host_cxx_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_CXX_COMPILER', + var_default=default_cxx_host_compiler, + ask_for_var=('Please specify which C++ compiler should be used as the ' + 'host C++ compiler.'), + check_success=os.path.exists, + error_msg='Invalid C++ compiler path. %s cannot be found.', + ) - # Set HOST_CXX_COMPILER - environ_cp['HOST_CXX_COMPILER'] = host_cxx_compiler write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler) def set_host_c_compiler(environ_cp): """Set HOST_C_COMPILER.""" default_c_host_compiler = which('gcc') or '' - ask_c_host_compiler = ( - 'Please specify which C compiler should be used as the' - ' host C compiler. [Default is %s]: ') % default_c_host_compiler - - while True: - host_c_compiler = get_from_env_or_user_or_default( - environ_cp, 'HOST_C_COMPILER', ask_c_host_compiler, - default_c_host_compiler) - if os.path.exists(host_c_compiler): - break - # Reset and retry - print('Invalid C compiler path. %s cannot be found' % host_c_compiler) - environ_cp['HOST_C_COMPILER'] = '' + host_c_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_C_COMPILER', + var_default=default_c_host_compiler, + ask_for_var=('Please specify which C compiler should be used as the host' + 'C compiler.'), + check_success=os.path.exists, + error_msg='Invalid C compiler path. %s cannot be found.', + ) - # Set HOST_C_COMPILER - environ_cp['HOST_C_COMPILER'] = host_c_compiler write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler) def set_computecpp_toolkit_path(environ_cp): """Set COMPUTECPP_TOOLKIT_PATH.""" - ask_computecpp_toolkit_path = ('Please specify the location where ComputeCpp ' - 'for SYCL %s is installed. [Default is %s]: ' - ) % (_TF_OPENCL_VERSION, - _DEFAULT_COMPUTECPP_TOOLKIT_PATH) - while True: - computecpp_toolkit_path = get_from_env_or_user_or_default( - environ_cp, 'COMPUTECPP_TOOLKIT_PATH', ask_computecpp_toolkit_path, - _DEFAULT_COMPUTECPP_TOOLKIT_PATH) + def toolkit_exists(toolkit_path): + """Check if a computecpp toolkit path is valid.""" if is_linux(): sycl_rt_lib_path = 'lib/libComputeCpp.so' else: sycl_rt_lib_path = '' - sycl_rt_lib_path_full = os.path.join(computecpp_toolkit_path, + sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path) - if os.path.exists(sycl_rt_lib_path_full): - break + exists = os.path.exists(sycl_rt_lib_path_full) + if not exists: + print('Invalid SYCL %s library path. %s cannot be found' % + (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) + return exists - print('Invalid SYCL %s library path. %s cannot be found' % - (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) - environ_cp['COMPUTECPP_TOOLKIT_PATH'] = '' + computecpp_toolkit_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='COMPUTECPP_TOOLKIT_PATH', + var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH, + ask_for_var=( + 'Please specify the location where ComputeCpp for SYCL %s is ' + 'installed.' % _TF_OPENCL_VERSION), + check_success=toolkit_exists, + error_msg='Invalid SYCL compiler path. %s cannot be found.', + suppress_default_error=True) - # Set COMPUTECPP_TOOLKIT_PATH - environ_cp['COMPUTECPP_TOOLKIT_PATH'] = computecpp_toolkit_path write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH', computecpp_toolkit_path) + def set_trisycl_include_dir(environ_cp): - """Set TRISYCL_INCLUDE_DIR""" + """Set TRISYCL_INCLUDE_DIR.""" + ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' 'include directory. (Use --config=sycl_trisycl ' 'when building with Bazel) ' '[Default is %s]: ' - ) % (_DEFAULT_TRISYCL_INCLUDE_DIR) + ) % (_DEFAULT_TRISYCL_INCLUDE_DIR) + while True: trisycl_include_dir = get_from_env_or_user_or_default( - environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, - _DEFAULT_TRISYCL_INCLUDE_DIR) + environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, + _DEFAULT_TRISYCL_INCLUDE_DIR) if os.path.exists(trisycl_include_dir): break @@ -905,50 +1147,30 @@ def set_trisycl_include_dir(environ_cp): write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) -def set_trisycl_include_dir(environ_cp): - """Set TRISYCL_INCLUDE_DIR.""" - ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' - 'include directory. (Use --config=sycl_trisycl ' - 'when building with Bazel) ' - '[Default is %s]: ') % ( - _DEFAULT_TRISYCL_INCLUDE_DIR) - while True: - trisycl_include_dir = get_from_env_or_user_or_default( - environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, - _DEFAULT_TRISYCL_INCLUDE_DIR) - if os.path.exists(trisycl_include_dir): - break - - print('Invalid triSYCL include directory, %s cannot be found' % - (trisycl_include_dir)) - - # Set TRISYCL_INCLUDE_DIR - environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir - write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) - def set_mpi_home(environ_cp): """Set MPI_HOME.""" + default_mpi_home = which('mpirun') or which('mpiexec') or '' default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) - ask_mpi_home = ('Please specify the MPI toolkit folder. [Default is %s]: ' - ) % default_mpi_home - while True: - mpi_home = get_from_env_or_user_or_default(environ_cp, 'MPI_HOME', - ask_mpi_home, default_mpi_home) - - if os.path.exists(os.path.join(mpi_home, 'include')) and os.path.exists( - os.path.join(mpi_home, 'lib')): - break - - print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % - (os.path.join(mpi_home, 'include'), - os.path.exists(os.path.join(mpi_home, 'lib')))) - environ_cp['MPI_HOME'] = '' + def valid_mpi_path(mpi_home): + exists = (os.path.exists(os.path.join(mpi_home, 'include')) and + os.path.exists(os.path.join(mpi_home, 'lib'))) + if not exists: + print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % + (os.path.join(mpi_home, 'include'), + os.path.exists(os.path.join(mpi_home, 'lib')))) + return exists - # Set MPI_HOME - environ_cp['MPI_HOME'] = str(mpi_home) + _ = prompt_loop_or_load_from_env( + environ_cp, + var_name='MPI_HOME', + var_default=default_mpi_home, + ask_for_var='Please specify the MPI toolkit folder.', + check_success=valid_mpi_path, + error_msg='', + suppress_default_error=True) def set_other_mpi_vars(environ_cp): @@ -983,47 +1205,25 @@ def set_other_mpi_vars(environ_cp): raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home) -def set_mkl(): - write_to_bazelrc('build:mkl --define using_mkl=true') - write_to_bazelrc('build:mkl -c opt') - print( - 'Add "--config=mkl" to your bazel command to build with MKL ' - 'support.\nPlease note that MKL on MacOS or windows is still not ' - 'supported.\nIf you would like to use a local MKL instead of ' - 'downloading, please set the environment variable \"TF_MKL_ROOT\" every ' - 'time before build.') - - -def set_monolithic(): - # Add --config=monolithic to your bazel command to use a mostly-static - # build and disable modular op registration support (this will revert to - # loading TensorFlow with RTLD_GLOBAL in Python). By default (without - # --config=monolithic), TensorFlow will build with a dependence on - # //tensorflow:libtensorflow_framework.so. - write_to_bazelrc('build:monolithic --define framework_shared_object=false') - # For projects which use TensorFlow as part of a Bazel build process, putting - # nothing in a bazelrc will default to a monolithic build. The following line - # opts in to modular op registration support by default: - write_to_bazelrc('build --define framework_shared_object=true') - - -def create_android_bazelrc_configs(): - # Flags for --config=android - write_to_bazelrc('build:android --crosstool_top=//external:android/crosstool') - write_to_bazelrc( - 'build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain') - # Flags for --config=android_arm - write_to_bazelrc('build:android_arm --config=android') - write_to_bazelrc('build:android_arm --cpu=armeabi-v7a') - # Flags for --config=android_arm64 - write_to_bazelrc('build:android_arm64 --config=android') - write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a') - - def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') +def set_windows_build_flags(): + if is_windows(): + # The non-monolithic build is not supported yet + write_to_bazelrc('build --config monolithic') + # Suppress warning messages + write_to_bazelrc('build --copt=-w --host_copt=-w') + # Output more verbose information when something goes wrong + write_to_bazelrc('build --verbose_failures') + + +def config_info_line(name, help_text): + """Helper function to print formatted help text for Bazel config options.""" + print('\t--config=%-12s\t# %s' % (name, help_text)) + + def main(): # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. @@ -1034,7 +1234,6 @@ def main(): reset_tf_configure_bazelrc() cleanup_makefile() setup_python(environ_cp) - run_gen_git_source(environ_cp) if is_windows(): environ_cp['TF_NEED_S3'] = '0' @@ -1083,8 +1282,19 @@ def main(): set_tf_cuda_clang(environ_cp) if environ_cp.get('TF_CUDA_CLANG') == '1': - # Set up which clang we should use as the cuda / host compiler. - set_clang_cuda_compiler_path(environ_cp) + if not is_windows(): + # Ask if we want to download clang release while building. + set_tf_download_clang(environ_cp) + else: + # We use bazel's generated crosstool on Windows and there is no + # way to provide downloaded toolchain for that yet. + # TODO(ibiryukov): Investigate using clang as a cuda compiler on + # Windows. + environ_cp['TF_DOWNLOAD_CLANG'] = '0' + + if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': + # Set up which clang we should use as the cuda / host compiler. + set_clang_cuda_compiler_path(environ_cp) else: # Set up which gcc nvcc should use as the host compiler # No need to set this on Windows @@ -1099,9 +1309,29 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) - set_mkl() - set_monolithic() - create_android_bazelrc_configs() + set_windows_build_flags() + + if workspace_has_any_android_rule(): + print('The WORKSPACE file has at least one of ["android_sdk_repository", ' + '"android_ndk_repository"] already set. Will not ask to help ' + 'configure the WORKSPACE. Please delete the existing rules to ' + 'activate the helper.\n') + else: + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) + + print('Preconfigured Bazel build configs. You can use any of the below by ' + 'adding "--config=<>" to your build command. See tools/bazel.rc for ' + 'more details.') + config_info_line('mkl', 'Build with MKL support.') + config_info_line('monolithic', 'Config for mostly static monolithic build.') if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index bfebe8a5678a2c0508b31f5dd898eac22186a072..da37564697a7159518a6ba71271f911713e3e58e 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -364,11 +364,9 @@ config_setting( visibility = ["//visibility:public"], ) -# Make a dummy rule that we can change "default" in select statements to. -# to disable dependencies in copybara. config_setting( - name = "dummy_disabled_internal", - values = {"define": "with_dummy_disabled_internal=true"}, + name = "override_eigen_strong_inline", + values = {"define": "override_eigen_strong_inline=true"}, visibility = ["//visibility:public"], ) @@ -409,6 +407,8 @@ filegroup( "//tensorflow/c:all_files", "//tensorflow/cc:all_files", "//tensorflow/cc/saved_model:all_files", + "//tensorflow/cc/saved_model/python:all_files", + "//tensorflow/cc/tools:all_files", "//tensorflow/compiler/aot:all_files", "//tensorflow/compiler/aot/tests:all_files", "//tensorflow/compiler/jit:all_files", @@ -427,6 +427,7 @@ filegroup( "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", "//tensorflow/compiler/xla/legacy_flags:all_files", + "//tensorflow/compiler/xla/python:all_files", "//tensorflow/compiler/xla/service:all_files", "//tensorflow/compiler/xla/service/cpu:all_files", "//tensorflow/compiler/xla/service/gpu:all_files", @@ -452,6 +453,7 @@ filegroup( "//tensorflow/contrib/cloud:all_files", "//tensorflow/contrib/cloud/kernels:all_files", "//tensorflow/contrib/cluster_resolver:all_files", + "//tensorflow/contrib/coder:all_files", "//tensorflow/contrib/compiler:all_files", "//tensorflow/contrib/copy_graph:all_files", "//tensorflow/contrib/crf:all_files", @@ -461,10 +463,13 @@ filegroup( "//tensorflow/contrib/data/python/kernel_tests:all_files", "//tensorflow/contrib/data/python/ops:all_files", "//tensorflow/contrib/decision_trees/proto:all_files", + "//tensorflow/contrib/deprecated:all_files", "//tensorflow/contrib/distributions:all_files", + "//tensorflow/contrib/eager/proto:all_files", "//tensorflow/contrib/eager/python:all_files", "//tensorflow/contrib/estimator:all_files", "//tensorflow/contrib/factorization:all_files", + "//tensorflow/contrib/factorization/examples:all_files", "//tensorflow/contrib/factorization/kernels:all_files", "//tensorflow/contrib/ffmpeg:all_files", "//tensorflow/contrib/ffmpeg/default:all_files", @@ -475,6 +480,7 @@ filegroup( "//tensorflow/contrib/graph_editor:all_files", "//tensorflow/contrib/grid_rnn:all_files", "//tensorflow/contrib/hooks:all_files", + "//tensorflow/contrib/hvx/clock_cycle_profiling:all_files", "//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files", "//tensorflow/contrib/image:all_files", "//tensorflow/contrib/input_pipeline:all_files", @@ -492,6 +498,8 @@ filegroup( "//tensorflow/contrib/layers/kernels:all_files", "//tensorflow/contrib/learn:all_files", "//tensorflow/contrib/learn/python/learn/datasets:all_files", + "//tensorflow/contrib/legacy_seq2seq:all_files", + "//tensorflow/contrib/libsvm:all_files", "//tensorflow/contrib/linalg:all_files", "//tensorflow/contrib/linear_optimizer:all_files", "//tensorflow/contrib/lite:all_files", @@ -516,15 +524,22 @@ filegroup( "//tensorflow/contrib/lookup:all_files", "//tensorflow/contrib/losses:all_files", "//tensorflow/contrib/makefile:all_files", + "//tensorflow/contrib/memory_stats:all_files", "//tensorflow/contrib/meta_graph_transform:all_files", "//tensorflow/contrib/metrics:all_files", "//tensorflow/contrib/model_pruning:all_files", - "//tensorflow/contrib/mpi_collectives:all_files", + "//tensorflow/contrib/model_pruning/examples/cifar10:all_files", + "//tensorflow/contrib/nccl:all_files", "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/nearest_neighbor:all_files", "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", + "//tensorflow/contrib/periodic_resample:all_files", "//tensorflow/contrib/predictor:all_files", + "//tensorflow/contrib/py2tf:all_files", + "//tensorflow/contrib/py2tf/convert:all_files", + "//tensorflow/contrib/py2tf/pyct:all_files", + "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files", "//tensorflow/contrib/quantize:all_files", "//tensorflow/contrib/receptive_field:all_files", "//tensorflow/contrib/reduce_slice_ops:all_files", @@ -567,6 +582,7 @@ filegroup( "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", "//tensorflow/core:all_files", + "//tensorflow/core/api_def:all_files", "//tensorflow/core/debug:all_files", "//tensorflow/core/distributed_runtime:all_files", "//tensorflow/core/distributed_runtime/rpc:all_files", @@ -577,6 +593,9 @@ filegroup( "//tensorflow/core/grappler/optimizers:all_files", "//tensorflow/core/grappler/utils:all_files", "//tensorflow/core/kernels:all_files", + "//tensorflow/core/kernels/batching_util:all_files", + "//tensorflow/core/kernels/data:all_files", + "//tensorflow/core/kernels/data/sql:all_files", "//tensorflow/core/kernels/fuzzing:all_files", "//tensorflow/core/kernels/hexagon:all_files", "//tensorflow/core/kernels/neon:all_files", @@ -591,6 +610,7 @@ filegroup( "//tensorflow/core/profiler/internal/advisor:all_files", "//tensorflow/core/util/ctc:all_files", "//tensorflow/core/util/tensor_bundle:all_files", + "//tensorflow/examples/adding_an_op:all_files", "//tensorflow/examples/android:all_files", "//tensorflow/examples/benchmark:all_files", "//tensorflow/examples/get_started/regression:all_files", @@ -598,10 +618,13 @@ filegroup( "//tensorflow/examples/image_retraining:all_files", "//tensorflow/examples/label_image:all_files", "//tensorflow/examples/learn:all_files", + "//tensorflow/examples/multibox_detector:all_files", "//tensorflow/examples/saved_model:all_files", "//tensorflow/examples/speech_commands:all_files", "//tensorflow/examples/tutorials/estimators:all_files", + "//tensorflow/examples/tutorials/layers:all_files", "//tensorflow/examples/tutorials/mnist:all_files", + "//tensorflow/examples/tutorials/monitors:all_files", "//tensorflow/examples/tutorials/word2vec:all_files", "//tensorflow/examples/wav_to_spectrogram:all_files", "//tensorflow/go:all_files", @@ -610,6 +633,7 @@ filegroup( "//tensorflow/java/src/main/native:all_files", "//tensorflow/python:all_files", "//tensorflow/python/data:all_files", + "//tensorflow/python/data/kernel_tests:all_files", "//tensorflow/python/data/ops:all_files", "//tensorflow/python/data/util:all_files", "//tensorflow/python/debug:all_files", @@ -623,6 +647,7 @@ filegroup( "//tensorflow/python/kernel_tests/random:all_files", "//tensorflow/python/ops/distributions:all_files", "//tensorflow/python/ops/linalg:all_files", + "//tensorflow/python/ops/losses:all_files", "//tensorflow/python/profiler:all_files", "//tensorflow/python/profiler/internal:all_files", "//tensorflow/python/saved_model:all_files", @@ -633,6 +658,7 @@ filegroup( "//tensorflow/tools/api/tests:all_files", "//tensorflow/tools/benchmark:all_files", "//tensorflow/tools/build_info:all_files", + "//tensorflow/tools/ci_build/gpu_build:all_files", "//tensorflow/tools/common:all_files", "//tensorflow/tools/compatibility:all_files", "//tensorflow/tools/dist_test/server:all_files", @@ -640,17 +666,20 @@ filegroup( "//tensorflow/tools/docker/notebooks:all_files", "//tensorflow/tools/docs:all_files", "//tensorflow/tools/git:all_files", + "//tensorflow/tools/graph_transforms:all_files", "//tensorflow/tools/mlpbtxt:all_files", "//tensorflow/tools/proto_text:all_files", "//tensorflow/tools/quantization:all_files", "//tensorflow/tools/test:all_files", "//tensorflow/user_ops:all_files", + "//third_party/eigen3:all_files", + "//third_party/fft2d:all_files", + "//third_party/flatbuffers:all_files", "//third_party/hadoop:all_files", - "//third_party/mpi:all_files", "//third_party/sycl:all_files", "//third_party/sycl/sycl:all_files", ], - visibility = [":__subpackages__"], + visibility = ["//visibility:public"], ) load( @@ -774,6 +803,7 @@ tf_cc_shared_object( "//tensorflow/cc:cc_ops", "//tensorflow/cc:client_session", "//tensorflow/cc:scope", + "//tensorflow/cc/profiler", "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index ef7eb5a4d16b29aecc34f33cb41dd7cf9450c5f2..f258bcd95684cc58c2ead3886b3ce74e4af6c5aa 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -42,6 +42,7 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", ], }), ) @@ -73,6 +74,7 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index bb41f92306b413d610bf115d144b15faa568ee14..6fc75a98f1e05c3971cb4546bd16f015c25b6709 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/while_loop.h" #include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/framework/op_gen_lib.h" #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -383,12 +384,11 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, // be less than the total node count. Status ValidateNoCycles(const Graph& g) { // TODO(nolivia): check this on a subset of the graph instead of all of it. - int total_num_nodes = g.num_node_ids(); // A node is ready when all of its inputs have been visited. std::vector ready; - std::vector pending_count(total_num_nodes, 0); + std::vector pending_count(g.num_node_ids(), 0); - for (int i = 0; i < total_num_nodes; ++i) { + for (int i = 0; i < g.num_node_ids(); ++i) { const Node* n = g.FindNodeId(i); if (n == nullptr) continue; pending_count[i] = n->in_edges().size(); @@ -421,7 +421,7 @@ Status ValidateNoCycles(const Graph& g) { } } - if (processed < total_num_nodes) { + if (processed < g.num_nodes()) { std::vector nodes_in_cycle; for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; ++i) { @@ -430,7 +430,7 @@ Status ValidateNoCycles(const Graph& g) { } } return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", total_num_nodes - processed, + "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); } return Status::OK(); @@ -580,6 +580,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( "invalid string tensor encoding (string #", i, " of ", srcarray.size(), "): ", status->status.error_message()); + delete[] base; return nullptr; } dst += consumed; @@ -589,6 +590,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( "invalid string tensor encoding (decoded ", (dst - base), " bytes, but the tensor is encoded in ", size, " bytes"); + delete[] base; return nullptr; } @@ -625,6 +627,73 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, return Status::OK(); } +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + // If any session has already run this node_id, mark this session as + // unrunnable. + for (auto it : graph->sessions) { + if (it.first->last_num_graph_nodes > op.node.id()) { + it.second = FailedPrecondition( + "Operation '", op.node.DebugString(), "' was changed by ", + mutation_type, + " after it was run by a session. Nodes can be mutated " + "only before they are executed by a session. Either don't modify " + "nodes after running them or create a new session."); + } + } +} + +namespace { + +// Helper method that creates a shape handle for a shape described by dims. +tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims( + tensorflow::shape_inference::InferenceContext* ic, int num_dims, + const int64_t* dims) { + if (num_dims != -1) { + std::vector dim_vec; + dim_vec.reserve(num_dims); + for (int i = 0; i < num_dims; ++i) { + dim_vec.push_back(ic->MakeDim(dims[i])); + } + return ic->MakeShape(dim_vec); + } else { + return ic->UnknownShape(); + } +} + +} // namespace + +void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, + int num_shapes_and_types, + const int64_t** shapes, + const int* ranks, + const TF_DataType* types, + TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return; + } + + auto shape_and_type_vec = + std::vector( + num_shapes_and_types); + for (int i = 0; i < num_shapes_and_types; ++i) { + tensorflow::shape_inference::ShapeHandle shape_handle = + ShapeHandleFromDims(ic, ranks[i], shapes[i]); + shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType( + shape_handle, static_cast(types[i])); + } + + ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec); +} + // Helpers for loading a TensorFlow plugin (a .so file). Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); @@ -930,7 +999,6 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, Node* node = &output.oper->node; mutex_lock l(graph->mu); - // Set the shape. tensorflow::shape_inference::InferenceContext* ic = graph->refiner.GetContext(node); if (ic == nullptr) { @@ -938,18 +1006,8 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, InvalidArgument("Node ", node->name(), " was not found in the graph"); return; } - - tensorflow::shape_inference::ShapeHandle new_shape; - if (num_dims != -1) { - std::vector dim_vec; - dim_vec.reserve(num_dims); - for (int i = 0; i < num_dims; ++i) { - dim_vec.push_back(ic->MakeDim(dims[i])); - } - new_shape = ic->MakeShape(dim_vec); - } else { - new_shape = ic->UnknownShape(); - } + tensorflow::shape_inference::ShapeHandle new_shape = + tensorflow::ShapeHandleFromDims(ic, num_dims, dims); status->status = graph->refiner.SetShape(node, output.index, new_shape); } @@ -1143,6 +1201,13 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, reinterpret_cast(values), num_values)); } +void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, + const char* value, size_t length) { + tensorflow::NameAttrList func_name; + func_name.set_name(std::string(value, value + length)); + desc->node_builder.Attr(attr_name, func_name); +} + void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, const int64_t* dims, int num_dims) { PartialTensorShape shape; @@ -1745,7 +1810,6 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, TF_Graph::TF_Graph() : graph(tensorflow::OpRegistry::Global()), refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), delete_requested(false), parent(nullptr), parent_inputs(nullptr) {} @@ -1755,7 +1819,7 @@ TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { g->mu.lock(); g->delete_requested = true; - const bool del = g->num_sessions == 0; + const bool del = g->sessions.empty(); g->mu.unlock(); if (del) delete g; } @@ -1835,6 +1899,16 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, opts->opts.prefix = prefix; } +void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_names) { + opts->opts.uniquify_names = uniquify_names; +} + +void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_prefix) { + opts->opts.uniquify_prefix = uniquify_prefix; +} + void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst) { @@ -1892,12 +1966,12 @@ void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, *opers = results->return_nodes.data(); } -void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes) { - *num_unused_input_mappings = results->unused_key_names.size(); - *src_names = results->unused_key_names.data(); - *src_indexes = results->unused_key_indexes.data(); + *num_missing_unused_input_mappings = results->missing_unused_key_names.size(); + *src_names = results->missing_unused_key_names.data(); + *src_indexes = results->missing_unused_key_indexes.data(); } void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { @@ -1937,18 +2011,21 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); } - // Populate unused map keys - DCHECK(tf_results->unused_key_names.empty()); - DCHECK(tf_results->unused_key_indexes.empty()); - DCHECK(tf_results->unused_key_names_data.empty()); - tf_results->unused_key_names.resize(results.unused_input_map_keys.size()); - tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size()); - for (int i = 0; i < results.unused_input_map_keys.size(); ++i) { - TensorId id = results.unused_input_map_keys[i]; - tf_results->unused_key_names_data.push_back(id.first.ToString()); - tf_results->unused_key_names[i] = - tf_results->unused_key_names_data.back().c_str(); - tf_results->unused_key_indexes[i] = id.second; + // Populate missing unused map keys + DCHECK(tf_results->missing_unused_key_names.empty()); + DCHECK(tf_results->missing_unused_key_indexes.empty()); + DCHECK(tf_results->missing_unused_key_names_data.empty()); + + size_t size = results.missing_unused_input_map_keys.size(); + tf_results->missing_unused_key_names.resize(size); + tf_results->missing_unused_key_indexes.resize(size); + + for (int i = 0; i < size; ++i) { + TensorId id = results.missing_unused_input_map_keys[i]; + tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); + tf_results->missing_unused_key_names[i] = + tf_results->missing_unused_key_names_data.back().c_str(); + tf_results->missing_unused_key_indexes[i] = id.second; } } @@ -2325,11 +2402,12 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, Session* session; status->status = NewSession(opt->options, &session); if (status->status.ok()) { + TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->num_sessions += 1; + graph->sessions[new_session] = Status::OK(); } - return new TF_Session(session, graph); + return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; @@ -2393,7 +2471,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->num_sessions += 1; + graph->sessions[session] = Status::OK(); session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ @@ -2408,8 +2486,8 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) { TF_Graph* const graph = s->graph; if (graph != nullptr) { graph->mu.lock(); - graph->num_sessions -= 1; - const bool del = graph->delete_requested && graph->num_sessions == 0; + graph->sessions.erase(s); + const bool del = graph->delete_requested && graph->sessions.empty(); graph->mu.unlock(); if (del) delete graph; } @@ -2425,6 +2503,13 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { mutex_lock session_lock(session->mu); session->graph->mu.lock(); const Graph& graph = session->graph->graph; + + status->status = session->graph->sessions[session]; + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { status->status = tensorflow::ValidateNoCycles(session->graph->graph); @@ -2580,4 +2665,54 @@ void TF_SessionPRun(TF_Session* session, const char* handle, output_values, target_names, nullptr, status); } +TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { + tensorflow::OpList op_list; + if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { + status->status = InvalidArgument("Unparseable OpList"); + return nullptr; + } + status->status = Status::OK(); + return new TF_ApiDefMap(op_list); +} + +void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } + +void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, + size_t text_len, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported in Android."); +#else + mutex_lock l(api_def_map->lock); + if (api_def_map->update_docs_called) { + status->status = FailedPrecondition( + "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been " + "called."); + return; + } + string api_def_text(text, text_len); + status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); +#endif // __ANDROID__ +} + +TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, + size_t name_len, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported in Android."); + return nullptr; +#else + mutex_lock l(api_def_map->lock); + if (!api_def_map->update_docs_called) { + api_def_map->api_def_map.UpdateDocs(); + api_def_map->update_docs_called = true; + } + string name_str(name, name_len); + const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); + + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(*api_def, ret); + return ret; +#endif // __ANDROID__ +} } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index bb569d67fcbcec29e9494236abd79b3e40db91cd..d2e45341bf1b9ee4579f84064550ce26041dd04a 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -511,6 +511,11 @@ TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, const TF_DataType* values, int num_values); +// Set a 'func' attribute to the specified name. +// `value` must point to a string of length `length` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, + const char* attr_name, + const char* value, size_t length); // Set `num_dims` to -1 to represent "unknown rank". Otherwise, // `dims` points to an array of length `num_dims`. `dims[i]` must be @@ -889,6 +894,20 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); +// Set whether to uniquify imported operation names. If true, imported operation +// names will be modified if their name already exists in the graph. If false, +// conflicting names will be treated as an error. Note that this option has no +// effect if a prefix is set, since the prefix will guarantee all names are +// unique. Defaults to false. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); + +// If true, the specified prefix will be modified if it already exists as an +// operation name or prefix in the graph. If false, a conflicting prefix will be +// treated as an error. This option has no effect if no prefix is specified. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); + // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. @@ -948,16 +967,16 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); // Fetches any input mappings requested via -// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any -// node in the imported graph def. The number of fetched mappings is returned in -// `num_unused_input_mappings`. The array of each mapping's source node name is -// returned in `src_names`, and the array of each mapping's source index is -// returned in `src_indexes`. +// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef +// and weren't used as input to any node in the imported graph def. The number +// of fetched mappings is returned in `num_missing_unused_input_mappings`. The +// array of each mapping's source node name is returned in `src_names`, and the +// array of each mapping's source index is returned in `src_indexes`. // // `*src_names`, `*src_indexes`, and the memory backing each string in // `src_names` are owned by and have the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes); // Deletes a results object returned by TF_GraphImportGraphDefWithResults(). @@ -1015,6 +1034,23 @@ TF_CAPI_EXPORT extern void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* grad, TF_Status* status); +// Returns the number of TF_Functions registered in `g`. +TF_CAPI_EXPORT extern int TF_GraphNumFunctions(TF_Graph* g); + +// Fills in `funcs` with the TF_Function* registered in `g`. +// `funcs` must point to an array of TF_Function* of length at least +// `max_func`. In usual usage, max_func should be set to the result of +// TF_GraphNumFunctions(g). In this case, all the functions registered in +// `g` will be returned. Else, an unspecified subset. +// +// If successful, returns the number of TF_Function* successfully set in +// `funcs` and sets status to OK. The caller takes ownership of +// all the returned TF_Functions. They must be deleted with TF_DeleteFunction. +// On error, returns 0, sets status to the encountered error, and the contents +// of funcs will be undefined. +TF_CAPI_EXPORT extern int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, + int max_func, TF_Status* status); + // Note: The following function may fail on very large protos in the future. TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, @@ -1504,6 +1540,49 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // in this address space. TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); +// TF_ApiDefMap encapsulates a collection of API definitions for an operation. +// +// This object maps the name of a TensorFlow operation to a description of the +// API to generate for it, as defined by the ApiDef protocol buffer ( +// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto) +// +// The ApiDef messages are typically used to generate convenience wrapper +// functions for TensorFlow operations in various language bindings. +typedef struct TF_ApiDefMap TF_ApiDefMap; + +// Creates a new TF_ApiDefMap instance. +// +// Params: +// op_list_buffer - TF_Buffer instance containing serialized OpList +// protocol buffer. (See +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto +// for the OpList proto definition). +// status - Set to OK on success and an appropriate error on failure. +TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, + TF_Status* status); + +// Deallocates a TF_ApiDefMap. +TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap); + +// Add ApiDefs to the map. +// +// `text` corresponds to a text representation of an ApiDefs protocol message. +// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto). +// +// The provided ApiDefs will be merged with existing ones in the map, with +// precedence given to the newly added version in case of conflicts with +// previous calls to TF_ApiDefMapPut. +TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, + const char* text, size_t text_len, + TF_Status* status); + +// Returns a serialized ApiDef protocol buffer for the TensorFlow operation +// named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, + const char* name, + size_t name_len, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index dcb818b88b6fca460852beb6e948d2eb6964f663..46271e0514f473099848a8573cb7cb6fad33f7dc 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -68,7 +68,7 @@ class NodeNameMapping { // This is a superset of values in name_mapping_. std::unordered_set used_names_; // Mapping from original node name from the graph to the normalized - // and uniqified version of it. + // and uniquified version of it. std::unordered_map name_mapping_; }; @@ -226,12 +226,17 @@ Status FillFunctionBody( } node_def->add_input(strings::StrCat("^", normalized)); } + + // A function is stateful if any of its nodes are stateful. + if (node->op_def().is_stateful()) { + fdef->mutable_signature()->set_is_stateful(true); + } } return Status::OK(); } // Graph to FunctionDef conversion. This code is closely modeled on the Python -// code in third_party/tensorflow/python/framework/function.py. +// code in tensorflow/python/framework/function.py. Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, bool append_hash_to_fn_name, const std::vector& body_nodes, @@ -307,7 +312,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, TF_RETURN_IF_ERROR( NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); for (const auto& output : output_ranges) { - const string& output_name = output.first; + const StringPiece& output_name = output.first; int index_start = output.second.first; int index_end = output.second.second; for (int i = index_start; i < index_end; ++i) { @@ -543,6 +548,28 @@ void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, status->status = g->graph.AddFunctionLibrary(fdef_lib); } +int TF_GraphNumFunctions(TF_Graph* g) { + tensorflow::mutex_lock l(g->mu); + return g->graph.flib_def().num_functions(); +} + +int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func, + TF_Status* status) { + tensorflow::FunctionDefLibrary lib; + { + tensorflow::mutex_lock l(g->mu); + lib = g->graph.flib_def().ToProto(); + } + const auto len = std::min(max_func, static_cast(lib.function_size())); + for (int i = 0; i < len; ++i) { + TF_Function* func = new TF_Function(); + func->fdef = lib.function(i); + funcs[i] = func; + } + status->status = tensorflow::Status::OK(); + return len; +} + void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def, TF_Status* status) { status->status = MessageToBuffer(func->fdef, output_func_def); diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index d5580b658992413ae6f9cb79ef88751ee28ce465..dbce66d2317a8e89288fab932cf69055f8b5a7f0 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -1462,7 +1463,11 @@ TEST_F(CApiFunctionTest, AppendHash) { /*append_hash=*/true); tensorflow::FunctionDef fdef; ASSERT_TRUE(GetFunctionDef(func_, &fdef)); +#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) + ASSERT_EQ(string("func_name_base_ZpgUD4x8oqk"), fdef.signature().name()); +#else ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name()); +#endif } TEST_F(CApiFunctionTest, GetOpDef) { @@ -1482,9 +1487,124 @@ TEST_F(CApiFunctionTest, GetOpDef) { EXPECT_EQ(op_def.name(), func_name_); EXPECT_EQ(op_def.input_arg_size(), 1); EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_FALSE(op_def.is_stateful()); TF_DeleteBuffer(buffer); } +void DefineStatefulFunction(const char* name, TF_Function** func) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Tensor* tensor_shape = Int32Tensor({37, 1}); + TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape"); + TF_Operation* random = + RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get()); + + TF_Output inputs[] = {}; + TF_Output outputs[] = {{random, 0}}; + *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1, + /*opers=*/nullptr, 0, inputs, 1, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, "", s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(*func, nullptr); + TF_DeleteTensor(tensor_shape); +} + +TEST_F(CApiFunctionTest, StatefulOpDef) { + DefineStatefulFunction(func_name_, &func_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Test we can retrieve function OpDef from graph + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Sanity check returned OpDef + string data(static_cast(buffer->data), buffer->length); + OpDef op_def; + op_def.ParseFromString(data); + EXPECT_EQ(op_def.name(), func_name_); + EXPECT_EQ(op_def.input_arg_size(), 0); + EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_TRUE(op_def.is_stateful()); + + TF_DeleteBuffer(buffer); +} + +void AssertEqual(TF_Function* f1, TF_Function* f2) { + string s1, s2; + tensorflow::FunctionDef fdef1, fdef2; + ASSERT_TRUE(GetFunctionDef(f1, &fdef1)); + ASSERT_TRUE(GetFunctionDef(f2, &fdef2)); + SerializeToStringDeterministic(fdef1, &s1); + SerializeToStringDeterministic(fdef2, &s2); + ASSERT_EQ(s1, s2); +} + +string GetName(TF_Function* func) { + tensorflow::FunctionDef fdef; + GetFunctionDef(func, &fdef); + return fdef.signature().name(); +} + +TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { + TF_Function* funcs[2]; + + // Get functions from empty graph + EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 0); + TF_GraphGetFunctions(host_graph_, nullptr, 0, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Define a function and add it to host_graph_ + TF_Function* func0; + DefineFunction("FooFunc0", &func0); + TF_GraphCopyFunction(host_graph_, func0, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Get this function from host_graph_ + EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 1); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 1, s_), 1); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + AssertEqual(func0, funcs[0]); + TF_DeleteFunction(funcs[0]); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 2, s_), 1); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + AssertEqual(func0, funcs[0]); + TF_DeleteFunction(funcs[0]); + + // Define a second function + TF_Function* func1; + DefineFunction("FooFunc1", &func1); + TF_GraphCopyFunction(host_graph_, func1, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Get both function from host_graph_ + EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 2); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 2, s_), 2); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + if (GetName(funcs[0]) == GetName(func0)) { + AssertEqual(func0, funcs[0]); + AssertEqual(func1, funcs[1]); + } else { + AssertEqual(func0, funcs[1]); + AssertEqual(func1, funcs[0]); + } + + TF_DeleteFunction(funcs[0]); + TF_DeleteFunction(funcs[1]); + + TF_DeleteFunction(func0); + TF_DeleteFunction(func1); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index bb04e01beec931a8ea66d0855eec9625d3a6a5ab..91667056e0eeb224b4b8a034766f11a123cd1a03 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -24,6 +24,9 @@ limitations under the License. #include #include +#ifndef __ANDROID__ +#include "tensorflow/core/framework/op_gen_lib.h" +#endif #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -81,12 +84,20 @@ struct TF_Graph { std::unordered_map name_map GUARDED_BY(mu); - // TF_Graph may only / must be deleted when - // num_sessions == 0 && delete_requested == true - - // num_sessions incremented by TF_NewSession, and decremented by + // The keys of this map are all the active sessions using this graph. + // Each value is the current "runnability" status of the corresponding + // session. Under normal conditions all statuses are Status::OK(), but + // if some operation is mutated after it was run by a session (this + // is detected in RecordMutation function), that session is no longer + // safe to run. Its status will contain the error that will be returned + // to the user, should she try running this session. + // + // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. - int num_sessions GUARDED_BY(mu); + // TF_Graph may only / must be deleted when + // sessions.size() == 0 && delete_requested == true + tensorflow::gtl::FlatMap sessions + GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph // Used to link graphs contained in TF_WhileParams to the parent graph that @@ -135,11 +146,11 @@ struct TF_ImportGraphDefOptions { struct TF_ImportGraphDefResults { std::vector return_tensors; std::vector return_nodes; - std::vector unused_key_names; - std::vector unused_key_indexes; + std::vector missing_unused_key_names; + std::vector missing_unused_key_indexes; - // Backing memory for unused_key_names values. - std::list unused_key_names_data; + // Backing memory for missing_unused_key_names values. + std::list missing_unused_key_names_data; }; struct TF_DeviceList { @@ -150,6 +161,22 @@ struct TF_Function { tensorflow::FunctionDef fdef; }; +struct TF_ApiDefMap { + explicit TF_ApiDefMap(const tensorflow::OpList& op_list) + : +#ifndef __ANDROID__ + api_def_map(op_list), +#endif + update_docs_called(false) { + } + +#ifndef __ANDROID__ + tensorflow::ApiDefMap api_def_map GUARDED_BY(lock); +#endif + bool update_docs_called GUARDED_BY(lock); + tensorflow::mutex lock; +}; + namespace tensorflow { class TensorCApi { @@ -167,6 +194,24 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +// Set the shapes and types of the output's handle. +// +// The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must +// all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the +// rank is known), then it must be equal to the length of `shapes[i]`; if +// `ranks[i] == 1`, then `shapes[i]` may be nullptr. +// +// TODO(akshayka): Implement a corresponding getter method. +void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, + int num_shapes_and_types, + const int64_t** shapes, + const int* ranks, + const TF_DataType* types, + TF_Status* status); + +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 6ec1db8ccfdb713f330b708e604bd4b502ff7202..df697e16d3d3fcaac66f967c0d3938450f0b0be6 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/node_def.pb_text.h" @@ -773,7 +774,7 @@ TEST(CAPI, ImportGraphDef_WithReturnOutputs) { TF_DeleteStatus(s); } -TEST(CAPI, ImportGraphDef_UnusedInputMappings) { +TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -816,7 +817,7 @@ TEST(CAPI, ImportGraphDef_UnusedInputMappings) { int num_unused_input_mappings; const char** src_names; int* src_indexes; - TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResultsMissingUnusedInputMappings( results, &num_unused_input_mappings, &src_names, &src_indexes); ASSERT_EQ(1, num_unused_input_mappings); EXPECT_EQ(string("fake"), string(src_names[0])); @@ -2027,6 +2028,77 @@ TEST_F(CApiAttributesTest, Errors) { EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); } +TEST(TestApiDef, TestCreateApiDef) { + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op.so", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + TF_Buffer op_list_buf = TF_GetOpList(lib); + status = TF_NewStatus(); + auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string op_name = "TestCApi"; + status = TF_NewStatus(); + auto* api_def_buf = + TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + tensorflow::ApiDef api_def; + EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length)); + EXPECT_EQ(op_name, api_def.graph_op_name()); + EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary()); + + TF_DeleteBuffer(api_def_buf); + TF_DeleteApiDefMap(api_def_map); + TF_DeleteLibraryHandle(lib); +} + +TEST(TestApiDef, TestCreateApiDefWithOverwrites) { + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op.so", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + TF_Buffer op_list_buf = TF_GetOpList(lib); + status = TF_NewStatus(); + auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string api_def_overwrites = R"(op: < + graph_op_name: "TestCApi" + summary: "New summary" +> +)"; + status = TF_NewStatus(); + TF_ApiDefMapPut(api_def_map, api_def_overwrites.c_str(), + api_def_overwrites.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string op_name = "TestCApi"; + status = TF_NewStatus(); + auto* api_def_buf = + TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + tensorflow::ApiDef api_def; + EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length)); + EXPECT_EQ(op_name, api_def.graph_op_name()); + EXPECT_EQ("New summary", api_def.summary()); + + TF_DeleteBuffer(api_def_buf); + TF_DeleteApiDefMap(api_def_map); + TF_DeleteLibraryHandle(lib); +} + #undef EXPECT_TF_META } // namespace diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index c291a2e440a8515e968b0ce0395b289080f04e8b..37439ff0beac5a5220460465e954b6c093ee1ba9 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -193,6 +193,15 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, return TF_FinishOperation(desc, s); } +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s) { + TF_OperationDescription* desc = + TF_NewOperation(graph, "RandomUniform", "random_uniform"); + TF_AddInput(desc, {shape, 0}); + TF_SetAttrType(desc, "dtype", dtype); + return TF_FinishOperation(desc, s); +} + void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name, TF_Operation** op) { TF_Operation* zero = ScalarConst( diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index d54733749248fa32c39d88bb0281d329dd50c7bd..3429009a71a863ae6b69b5cd29ace3c7fd078f4c 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -74,7 +74,10 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); -// Split `input` along the first dimention into 3 tensors +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s); + +// Split `input` along the first dimension into 3 tensors TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name = "split3"); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index d533758e360bc44a6f52f57eaae5b222e0482860..74190cb135ac6c17bfcc9d8bd2f7c75ac5e8c076 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -33,7 +33,7 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", ], - }), + }) + ["//tensorflow/core:gpu_runtime"], ) tf_cuda_library( @@ -55,6 +55,10 @@ tf_cuda_library( tf_cuda_cc_test( name = "c_api_test", srcs = ["c_api_test.cc"], + tags = [ + "guitar", + "multi_gpu", + ], deps = [ ":c_api", "//tensorflow/core:lib", @@ -113,3 +117,9 @@ cc_library( "//tensorflow/core:lib", ], ) + +filegroup( + name = "headers", + srcs = ["c_api.h"], + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 706c89536db019c7f7389af576815746b2425520..04a415b909ba3e76dfc12a3522f85d290ba6d36f 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" +#include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" @@ -97,7 +98,10 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { status->status = tensorflow::Status::OK(); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); + { + tensorflow::mutex_lock ml(ctx->cache_mu); + tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); + } TF_Graph* graph = ctx->session->graph; TF_DeleteSession(ctx->session, status); TF_DeleteGraph(graph); @@ -109,6 +113,11 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { return TF_SessionListDevices(ctx->session, status); } +void TFE_ContextClearCaches(TFE_Context* ctx) { + tensorflow::mutex_lock ml(ctx->cache_mu); + tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); +} + TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); @@ -164,23 +173,13 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, bool is_same_device = (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); const bool dst_cpu = IsCPU(dstd); + const bool src_cpu = IsCPU(srcd); if (is_same_device) { return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd); } - const bool src_cpu = IsCPU(srcd); - if (src_cpu == dst_cpu) { - TF_SetStatus( - status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat( - "TFE_TensorHandleCopyToDevice requires either the source " - "TFE_TensorHandle be on or the destination device be on CPU " - "or be the same (they are ", - DeviceName(srcd), " and ", DeviceName(dstd), " in this call)") - .c_str()); - return nullptr; - } tensorflow::Tensor* src = &(h->t); - if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) { + if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && + !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { TF_SetStatus( status, TF_INVALID_ARGUMENT, tensorflow::strings::StrCat("Can't copy Tensor with type ", @@ -189,26 +188,22 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, .c_str()); return nullptr; } - if (src_cpu) { - tensorflow::Tensor dst( - dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(), - src->shape()); - if (src->shape().num_elements() == 0) { - return new TFE_TensorHandle(dst, dstd); - } - tensorflow::Notification n; - dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice( - src, dstd, &dst, [status, &n](const tensorflow::Status& s) { - status->status = s; - n.Notify(); - }); - n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd) - : nullptr; - } - CHECK(dst_cpu); - tensorflow::Tensor dst(src->dtype(), src->shape()); - tensorflow::Notification n; + tensorflow::AllocatorAttributes attr; + if (src->dtype() == tensorflow::DT_VARIANT) { + attr.set_on_host(true); + } + tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); + if (src->shape().num_elements() == 0) { + return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd); + } + tensorflow::DeviceContext* src_device_context = nullptr; + if (!src_cpu) { + src_device_context = srcd->tensorflow_gpu_device_info()->default_context; + } + tensorflow::DeviceContext* dst_device_context = nullptr; + if (!dst_cpu) { + dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; + } // TODO(ashankar): The Sync() call below may be more aggressive than // necessary. It is based on knowledge of implementation details - that // GPU devices are implemented using 3 streams - one for host->device copies, @@ -217,16 +212,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, // but more than necessary (since it waits for operations that might have // nothing to do with this tensor to complete). status->status = srcd->Sync(); - if (!status->status.ok()) return nullptr; - srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU( - src, "IGNORE_MY_TENSOR_NAME", srcd, &dst, - [status, &n](const tensorflow::Status& s) { - status->status = s; - n.Notify(); - }); + tensorflow::Notification n; + tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, + srcd, dstd, tensorflow::AllocatorAttributes(), + tensorflow::AllocatorAttributes(), src, &dst, + [status, &n](const tensorflow::Status& s) { + status->status = s; + n.Notify(); + }); n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr) - : nullptr; + return (TF_GetCode(status) == TF_OK) + ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd) + : nullptr; } TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, @@ -505,8 +502,11 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, std::vector outputs(1); const tensorflow::MemoryTypeVector* output_memory_types = nullptr; tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name()); - tensorflow::KernelAndDevice* kernel = - tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); + tensorflow::KernelAndDevice* kernel; + { + tensorflow::tf_shared_lock l(ctx->cache_mu); + kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); + } if (kernel == nullptr) { const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); @@ -522,6 +522,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, delete kernel; return; } + tensorflow::mutex_lock ml(ctx->cache_mu); tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } std::vector copied_tensors; @@ -534,19 +535,54 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } return; } + std::unique_ptr maybe_stats; + if (ctx->should_store_metadata.load()) { + maybe_stats.reset(new tensorflow::NodeExecStats); + maybe_stats->set_node_name(op->name); + maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); + maybe_stats->set_op_start_rel_micros(0); + maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); + // TODO(apassos) track referenced tensors + } // WARNING: kernel->Run utilizes the FunctionLibraryRuntime // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation - // of FunctionLibraryRuntime tells use that func_lib_def is not accessed by + // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. // This is quite subtle. Re-work things to make this better? (Would it make // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). - status->status = kernel->Run(&op->inputs, &outputs); + // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats + // for ops which are a part of functions. + status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get()); for (auto* t : copied_tensors) { TFE_DeleteTensorHandle(t); } if (!status->status.ok()) return; + if (maybe_stats != nullptr) { + maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - + maybe_stats->all_start_micros()); + tensorflow::mutex_lock ml(ctx->metadata_mu); + if (ctx->should_store_metadata.load()) { + auto* step_stats = ctx->run_metadata.mutable_step_stats(); + // Lazily initialize the RunMetadata with information about all devices if + // this is the first call. + while (step_stats->dev_stats_size() < ctx->devices().size()) { + step_stats->add_dev_stats(); + } + // Find the current device's index. + int device_idx = 0; + for (int i = 0; i < ctx->devices().size(); ++i) { + if (ctx->devices()[i] == device) { + device_idx = i; + break; + } + } + // Populate the device stats for this device. + auto* dev_stats = step_stats->mutable_dev_stats(device_idx); + dev_stats->set_device(device->name()); + *dev_stats->add_node_stats() = *maybe_stats; + } + } *num_retvals = std::min(*num_retvals, outputs.size()); for (int i = 0; i < *num_retvals; ++i) { tensorflow::Device* d = IsCPU(device) ? nullptr : device; @@ -593,3 +629,20 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( } return &h->t; } + +void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { + ctx->should_store_metadata.store(true); +} + +void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { + tensorflow::mutex_lock ml(ctx->metadata_mu); + ctx->should_store_metadata.store(false); + ctx->run_metadata.Clear(); +} + +void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, + TF_Status* status) { + tensorflow::mutex_lock ml(ctx->metadata_mu); + status->status = MessageToBuffer(ctx->run_metadata, buf); + ctx->run_metadata.Clear(); +} diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index ca105962df0d6655946304159937621022e7fcba..9b0fd037da35f31e9b97f29b1269bbca9e4c849d 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_C_EAGER_C_API_H_ // C API extensions to experiment with eager execution of kernels. +// WARNING: Unlike tensorflow/c/c_api.h, the API here is not guaranteed to be +// stable and can change without notice. #include "tensorflow/c/c_api.h" @@ -87,6 +89,10 @@ TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); +// Clears the internal caches in the TFE context. Useful when reseeding random +// ops. +TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx); + // A handle to a tensor on a device. // // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, @@ -207,6 +213,19 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status); +// Enables tracing of RunMetadata on the ops executed from this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx); + +// Disables tracing of RunMetadata on the ops executed from this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx); + +// Populates the passed-in buffer with a serialized RunMetadata protocol buffer +// containing any run metadata information accumulated so far and clears this +// information. +TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, + TF_Buffer* buf, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 0971e2ab2fe98cc8bf6f631f41d5adce90ee7051..55a04d48bad63a8c19ffdc39675b1e1b70ac80d7 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -58,15 +58,21 @@ struct TFE_Context { // session->devices[i]. std::unique_ptr pflr; + tensorflow::mutex cache_mu; std::unordered_map - kernel_cache; + kernel_cache GUARDED_BY(cache_mu); tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { return pflr->GetFLR(d->name()); } const std::vector& devices() { return session->devices; } + + // Whether we should compute RunMetadata. + std::atomic should_store_metadata{false}; + tensorflow::mutex metadata_mu; + tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); }; struct TFE_TensorHandle { diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 3fe0b7efa11bc619ed98bf9a1634ade5b6ed0a7c..423a7e1ff71bfdc5f51e36ae63359869ea079ddc 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/config.pb.h" using tensorflow::string; @@ -216,6 +217,64 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + const int num_devices = TF_DeviceListCount(devices); + + const char* kCPUDevice = "CPU:0"; + if (num_devices < 3) { + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + return; + } + const string gpu_1_name(TF_DeviceListName(devices, 1, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + const string gpu_2_name(TF_DeviceListName(devices, 2, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + TFE_TensorHandle* hdevice = + TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_1_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + + TFE_TensorHandle* hdevice2 = TFE_TensorHandleCopyToDevice( + hdevice, ctx, gpu_2_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + TFE_DeleteTensorHandle(hdevice); + // Copy back to CPU + TFE_TensorHandle* hcopy = + TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + TFE_DeleteTensorHandle(hdevice2); + + // Ensure that the contents are the same! + TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get()); + TFE_DeleteTensorHandle(hcopy); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)); + EXPECT_EQ( + 0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t))); + TF_DeleteTensor(tcopy); + + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + TEST(CAPI, TensorHandleSilentCopy) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -295,6 +354,47 @@ TEST(CAPI, Execute) { TF_DeleteStatus(status); } +TEST(CAPI, ExecuteWithTracing) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + TFE_ContextEnableRunMetadata(ctx); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TF_Buffer* b = TF_NewBuffer(); + TFE_ContextExportRunMetadata(ctx, b, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + tensorflow::RunMetadata rm; + EXPECT_TRUE( + rm.ParseFromString({reinterpret_cast(b->data), b->length})); + TF_DeleteBuffer(b); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TF_DeleteStatus(status); +} + TEST(CAPI, Function) { // First create a simple identity function. TF_Graph* function_graph = TF_NewGraph(); diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index 38066682a9fc5038c34a4ac3b20a67ceb08ab951..3a9951e14de3a70e0b9e47fa62e6342e063c4bed 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -262,7 +262,8 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, } Status KernelAndDevice::Run(std::vector* input_tensors, - std::vector* output_tensors) { + std::vector* output_tensors, + NodeExecStats* stats) { gtl::InlinedVector inputs; for (Tensor& t : *input_tensors) { inputs.push_back(TensorValue(&t)); @@ -284,6 +285,9 @@ Status KernelAndDevice::Run(std::vector* input_tensors, params.function_library = flib_; params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; + if (stats != nullptr) { + params.track_allocations = true; + } // TODO(apassos): use a thread pool. std::function)> runner = [](std::function f) { f(); }; @@ -297,6 +301,28 @@ Status KernelAndDevice::Run(std::vector* input_tensors, for (int i = 0; i < context.num_outputs(); ++i) { output_tensors->push_back(Tensor(*context.mutable_output(i))); } + if (stats != nullptr) { + for (const auto& allocator_pair : context.wrapped_allocators()) { + AllocatorMemoryUsed* memory = stats->add_memory(); + memory->set_allocator_name(allocator_pair.first->Name()); + auto sizes = allocator_pair.second->GetSizes(); + memory->set_total_bytes(std::get<0>(sizes)); + memory->set_peak_bytes(std::get<1>(sizes)); + memory->set_live_bytes(std::get<2>(sizes)); + + AllocatorStats allocator_stats; + allocator_pair.first->GetStats(&allocator_stats); + memory->set_allocator_bytes_in_use(allocator_stats.bytes_in_use); + allocator_pair.second->GetRecordsAndUnRef(); + } + auto* ms = stats->mutable_memory_stats(); + ms->set_temp_memory_size(context.temp_memory_size()); + for (const auto& alloc_id : context.persistent_alloc_ids()) { + ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); + } + + ms->set_persistent_memory_size(context.persistent_memory_allocated()); + } return Status::OK(); } diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index fb97e94a94103d17164cb30f6c6e0ed3e07dc103..e28a416e67f8382dbd490648106a7eb6e5fcfd13 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -175,7 +175,8 @@ class KernelAndDevice { : device_(nullptr), flib_(nullptr), rendez_(rendez) {} // TODO(ashankar): Handle list-valued inputs. - Status Run(std::vector* inputs, std::vector* outputs); + Status Run(std::vector* inputs, std::vector* outputs, + NodeExecStats* stats); const OpKernel* kernel() const { return kernel_.get(); } diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 3236c6be0ec5281e8099219968dd5f5c6c2048c3..2ccca66f672b96b3c782ddbfc828eeda270cebee 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -96,7 +96,7 @@ TEST(KernelAndDevice, Run) { KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel); ASSERT_TRUE(s.ok()) << s; std::vector outputs; - s = kernel.Run(&inputs, &outputs); + s = kernel.Run(&inputs, &outputs, nullptr); ASSERT_TRUE(s.ok()) << s; ASSERT_EQ(1, outputs.size()); const Tensor& out = outputs[0]; @@ -183,7 +183,7 @@ void BM_KernelAndDeviceRun(int iters) { KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(kernel.Run(&inputs, &outputs)); + TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr)); } } BENCHMARK(BM_KernelAndDeviceRun); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index f52248e7d567b8edd911c6dba1786ceb5d5c721c..2b65e38f54090af6731685f78d5f7f914a875e3c 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -161,7 +161,7 @@ class GradientTape { // the tape refer to it); to aid in tape garbage collection. std::unordered_map tensor_usage_; - // If true, all activations are deleted in the first call to ComputeGradient. + // If false, all activations are deleted in the first call to ComputeGradient. // Else, only when this is destructed. bool persistent_; }; @@ -350,7 +350,7 @@ BackpropInitialState PrepareBackprop( // Call destructors for all unneeded gradient functions and // clear the op_tape. We can clear the tape because ownership of // backward functions that will be used for gradient computation - // has been transfered to `result`. + // has been transferred to `result`. for (const auto& op_pair : *op_tape) { op_pair.second.backward_function_deleter(); } @@ -491,6 +491,7 @@ Status GradientTape::ComputeGradient( state.op_tape.erase(op_it); std::vector out_gradients; out_gradients.reserve(trace.output_tensor_info.size()); + bool any_gradient_nonzero = false; for (int i = 0; i < trace.output_tensor_info.size(); ++i) { const int64 id = trace.output_tensor_info[i].id; auto grad_it = gradients.find(id); @@ -506,6 +507,7 @@ Status GradientTape::ComputeGradient( trace.output_tensor_info[i].dtype)); } } else { + any_gradient_nonzero = true; out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); @@ -513,14 +515,26 @@ Status GradientTape::ComputeGradient( } } std::vector in_gradients; - Status s = vspace.CallBackwardFunction(trace.backward_function, - out_gradients, &in_gradients); - if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); - } - if (!s.ok()) { - cleanup(); - return s; + if (any_gradient_nonzero) { + Status s = vspace.CallBackwardFunction(trace.backward_function, + out_gradients, &in_gradients); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + if (!s.ok()) { + cleanup(); + return s; + } + } else { + in_gradients.resize(trace.input_tensor_id.size()); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + for (Gradient* grad : out_gradients) { + if (grad != nullptr) { + vspace.DeleteGradient(grad); + } + } } VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " << trace.input_tensor_id.size() << " sources"; diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index ba5a9268b4f671499590d66fb41060dd18e1ce47..6e37cdb5f4beea53d4a2ded0705ae482d0bc2d68 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -22,6 +22,7 @@ namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { mutex_lock l(graph->mu); graph->graph.AddControlEdge(&input->node, &op->node); + RecordMutation(graph, *op, "adding control input"); } void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, @@ -36,11 +37,13 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, mutex_lock l(graph->mu); op->node.AddAttr(attr_name, attr_val); + RecordMutation(graph, *op, "setting attribute"); } void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { mutex_lock l(graph->mu); op->node.set_requested_device(device); + RecordMutation(graph, *op, "setting device"); } void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, @@ -75,6 +78,25 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); + + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst.oper, "updating input tensor"); + } +} + +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { + mutex_lock l(graph->mu); + std::vector control_edges; + for (const Edge* edge : op->node.in_edges()) { + if (!edge->IsControlEdge()) continue; + control_edges.push_back(edge); + } + for (const Edge* edge : control_edges) { + graph->graph.RemoveControlEdge(edge); + } } } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index f54585b0a1034ff108202272a11416e34985959e..b51ef2b53122802fef598a26bd6f1843976f11b0 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -35,6 +35,8 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index e354831d7d25af83c068a68a4f844056263a598c..ddcee3deee444382f4bdb206de6f06ee62265a51 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -421,7 +421,7 @@ tf_cc_test( tf_gen_op_wrappers_cc( name = "cc_ops", - api_def_srcs = ["//tensorflow/core:base_api_def"], + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], op_lib_names = [ "array_ops", "audio_ops", @@ -448,7 +448,6 @@ tf_gen_op_wrappers_cc( "ops/const_op.h", "ops/standard_ops.h", ], - override_file = "ops/op_gen_overrides.pbtxt", pkg = "//tensorflow/core", ) @@ -527,14 +526,13 @@ cc_library_with_android_deps( ], copts = tf_copts(), data = [ - "//tensorflow/core:base_api_def", + "//tensorflow/core/api_def:base_api_def", ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", - "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", ], @@ -547,15 +545,11 @@ tf_cc_test( "framework/cc_op_gen.h", "framework/cc_op_gen_test.cc", ], - data = [ - "//tensorflow/cc:ops/op_gen_overrides.pbtxt", - ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", - "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index dfbac9788e16e9c7c65abcd1ea213b51d5d5d060..ea5cf5a1f12be316cc6e0d0a02cd3caf4d177400 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -23,7 +23,13 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) +namespace { + +using ops::Add; +using ops::Const; +using ops::Mul; +using ops::Placeholder; +using ops::Sub; TEST(ClientSessionTest, Basic) { Scope root = Scope::NewRootScope(); @@ -89,4 +95,5 @@ TEST(ClientSessionTest, MultiThreaded) { test::ExpectTensorEqual(outputs[0], test::AsTensor({-1, 2}, {2})); } -} // end namespace tensorflow +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d889c518f9c38a9f070970b37a2ad4b1fc26671b..a40ad1ffc3b262840e6ca0043139b1b61e04510d 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -1057,16 +1057,9 @@ string MakeInternal(const string& fname) { } // namespace void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, - const string& dot_h_fname, const string& dot_cc_fname, - const string& overrides_fnames) { + const string& dot_h_fname, const string& dot_cc_fname) { Env* env = Env::Default(); - // Load the override map. - OpGenOverrideMap override_map; - if (!overrides_fnames.empty()) { - TF_CHECK_OK(override_map.LoadFileList(env, overrides_fnames)); - } - // Write the initial boilerplate to the .h and .cc files. std::unique_ptr h = nullptr; std::unique_ptr cc = nullptr; diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h index cea28990144b9371e8009ce13f912b44044f9aac..1b5f7dd923731e56ab3d7e5288d17fef9eb3beb0 100644 --- a/tensorflow/cc/framework/cc_op_gen.h +++ b/tensorflow/cc/framework/cc_op_gen.h @@ -24,8 +24,7 @@ namespace tensorflow { /// Result is written to files dot_h and dot_cc. void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, - const string& dot_h_fname, const string& dot_cc_fname, - const string& overrides_fnames); + const string& dot_h_fname, const string& dot_cc_fname); } // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen_main.cc b/tensorflow/cc/framework/cc_op_gen_main.cc index 326d5668b8803ee39ffe24900c92e1db87b93601..3157792e15a006555e4924eea3c72ea643e79c1c 100644 --- a/tensorflow/cc/framework/cc_op_gen_main.cc +++ b/tensorflow/cc/framework/cc_op_gen_main.cc @@ -28,7 +28,7 @@ namespace tensorflow { namespace { void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, - const std::string& overrides_fnames, bool include_internal, + bool include_internal, const std::vector& api_def_dirs) { OpList ops; OpRegistry::Global()->Export(include_internal, &ops); @@ -49,7 +49,7 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, api_def_map.UpdateDocs(); - WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames); + WriteCCOps(ops, api_def_map, dot_h, dot_cc); } } // namespace @@ -57,24 +57,21 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, int main(int argc, char* argv[]) { tensorflow::port::InitMain(argv[0], &argc, &argv); - // TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt - // as an argument. - if (argc != 6) { + if (argc != 5) { for (int i = 1; i < argc; ++i) { fprintf(stderr, "Arg %d = %s\n", i, argv[i]); } fprintf(stderr, - "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal " + "Usage: %s out.h out.cc include_internal " "api_def_dirs1,api_def_dir2 ...\n" " include_internal: 1 means include internal ops\n", argv[0]); exit(1); } - bool include_internal = tensorflow::StringPiece("1") == argv[4]; + bool include_internal = tensorflow::StringPiece("1") == argv[3]; std::vector api_def_dirs = tensorflow::str_util::Split( - argv[5], ",", tensorflow::str_util::SkipEmpty()); - tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal, - api_def_dirs); + argv[4], ",", tensorflow::str_util::SkipEmpty()); + tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal, api_def_dirs); return 0; } diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc index 0b7e720a5c7b343415eee1aa157b8de755a1e1a5..1e0f2d241bb350897a840dda90d6d0c009b1daad 100644 --- a/tensorflow/cc/framework/cc_op_gen_test.cc +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -24,10 +24,6 @@ limitations under the License. namespace tensorflow { namespace { -// TODO(annarev): Remove this op_gen_overrides.pbtxt reference. -// It is needed only because WriteCCOps takes it as an argument. -constexpr char kOverridesFnames[] = - "tensorflow/cc/ops/op_gen_overrides.pbtxt"; constexpr char kBaseOpDef[] = R"( op { name: "Foo" @@ -96,7 +92,7 @@ void GenerateCcOpFiles(Env* env, const OpList& ops, const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h"); const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc"); - WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames); + WriteCCOps(ops, api_def_map, h_file_path, cc_file_path); TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text)); TF_ASSERT_OK( diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index 5da23036eaadbef270ba839357dc4613bf3bf490..ac05e3cf95b1ce4009ee1424713baf2d34902a94 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -22,8 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - +namespace ops { namespace { Output Linear(const Scope& scope, Input x, Input w, Input b) { @@ -39,8 +38,6 @@ void GetColocationConstraints(const Output& tensor, constraints)); } -} // namespace - TEST(CCOpTest, Basic) { Scope root = Scope::NewRootScope(); auto c = Const(root, {{1, 1}}); @@ -249,4 +246,6 @@ TEST(CCOpTest, InvalidFinalize) { string::npos); } +} // namespace +} // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc index fdc457f40af875d7c0c243246755d0cb87c44a62..d4f0a7f5ab3716be41e22c02a21aca028f76fb88 100644 --- a/tensorflow/cc/framework/gradient_checker_test.cc +++ b/tensorflow/cc/framework/gradient_checker_test.cc @@ -24,10 +24,18 @@ limitations under the License. #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::Complex; +using ops::Const; +using ops::MatMul; +using ops::Placeholder; +using ops::Real; +using ops::Split; +using ops::Square; +using ops::Stack; +using ops::Unstack; + TEST(GradientCheckerTest, BasicFloat) { Scope scope = Scope::NewRootScope(); TensorShape shape({2, 4, 3}); diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 07a062e704ed6ffc6389b5897309957a1bfcd1c2..26e3170ad8e4f4fba1c2dc014086acf24d949f72 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -26,10 +26,20 @@ limitations under the License. #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::Assign; +using ops::Const; +using ops::Identity; +using ops::MatMul; +using ops::OnesLike; +using ops::Placeholder; +using ops::Square; +using ops::Stack; +using ops::StopGradient; +using ops::Unstack; +using ops::Variable; + // TODO(andydavis) Add more unit tests once more gradient functions are ported. class GradientsTest : public ::testing::Test { protected: diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 455d7330c10cf230462869475f25a1f1b9bf9e9e..4a215fcc9299cf8b8da04cbf151640631ed0d449 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { +namespace { + using namespace ops; // NOLINT(build/namespaces) using ops::internal::MirrorPadGrad; -namespace { - class ArrayGradTest : public ::testing::Test { protected: ArrayGradTest() : scope_(Scope::NewRootScope()) {} diff --git a/tensorflow/cc/gradients/data_flow_grad_test.cc b/tensorflow/cc/gradients/data_flow_grad_test.cc index 734dfd3af97b856a7c8c4894c4a6d1a3ade10992..0ba3c0e27b1e545a30925ea3ef9e2c54dc9d0ae9 100644 --- a/tensorflow/cc/gradients/data_flow_grad_test.cc +++ b/tensorflow/cc/gradients/data_flow_grad_test.cc @@ -23,10 +23,13 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::Const; +using ops::DynamicPartition; +using ops::DynamicStitch; +using ops::Placeholder; + class DataFlowGradTest : public ::testing::Test { protected: DataFlowGradTest() : scope_(Scope::NewRootScope()) {} diff --git a/tensorflow/cc/gradients/grad_testutil.cc b/tensorflow/cc/gradients/grad_testutil.cc index 04b29d4e8b21eeee200d9e7390868d701eda3c22..304117d3719346202d3a8a18637f7c915d4a47f9 100644 --- a/tensorflow/cc/gradients/grad_testutil.cc +++ b/tensorflow/cc/gradients/grad_testutil.cc @@ -18,16 +18,14 @@ limitations under the License. #include "tensorflow/cc/framework/grad_op_registry.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace test { Status CallGradFunction(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - GradFunc grad_fn; - TF_RETURN_IF_ERROR( - GradOpRegistry::Global()->Lookup(op.node()->type_string(), &grad_fn)); + ops::GradFunc grad_fn; + TF_RETURN_IF_ERROR(ops::GradOpRegistry::Global()->Lookup( + op.node()->type_string(), &grad_fn)); TF_RETURN_IF_ERROR(grad_fn(scope, op, grad_inputs, grad_outputs)); TF_RETURN_IF_ERROR(scope.status()); return Status::OK(); diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index d7446b9560fd7dc8377ea3710641906b274313a9..52c177212a8c88f1857defcc38de4a01ac47dab0 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -473,6 +473,41 @@ Status AddNGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("AddN", AddNGrad); +Status PowGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto x = ConjugateHelper(scope, op.input(0)); + auto y = ConjugateHelper(scope, op.input(1)); + auto z = ConjugateHelper(scope, op.output(0)); + auto grad = grad_inputs[0]; + // grad * y * pow(x, y - 1) + auto one = Cast(scope, Const(scope, 1.0), y.type()); + auto gx_1 = Mul(scope, + Mul(scope, grad, y), + Pow(scope, x, Sub(scope, y, one))); + // Avoid false singularity at x = 0 + DataType x_dtype = x.type(); + auto zero = Cast(scope, Const(scope, 0.0), x_dtype); + if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) { + // real(x) < 0 is fine for the complex case + auto log_x = Where3(scope, + NotEqual(scope, x, zero), + Log(scope, x), + ZerosLike(scope, x)); + auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x); + return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1); + } else { + // There's no sensible real value to return if x < 0, so return 0 + auto log_x = Where3(scope, + Greater(scope, x, zero), + Log(scope, x), + ZerosLike(scope, x)); + auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x); + return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1); + } +} +REGISTER_GRADIENT_OP("Pow", PowGrad); + // MaximumMinimumGradCommon adds shared ops to calculate gradients for // the binary Maximum and Minimum ops. Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op, @@ -794,6 +829,183 @@ Status MinOrMaxGrad(const Scope& scope, const Operation& op, REGISTER_GRADIENT_OP("Min", MinOrMaxGrad); REGISTER_GRADIENT_OP("Max", MinOrMaxGrad); +Status ProdGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto zero = Const(scope, 0); + auto one = Const(scope, 1); + + // The gradient can be expressed by dividing the product by each entry of + // the input tensor. If our input is + // [ + // [3, 4], + // [5, 6], + // [7, 8] + // ] + // and we do a Prod operation on the axis 1, we will obtain [[105, 192]]. + // The gradient will have the same shape as the input + // [ + // [105/3, 192/4], + // dz * [105/5, 192/6], + // [105/7, 192/6] + // ] + // If the input contains a zero, the division is impossible but + // if we take the calculation that gave the first gradient + // (3 * 5 * 6)/3 is equal to 5 * 6 + // the trick will be to cumprod the elements on the axis without + // the element at the current position (3 in the example above). + // We will take as example: + // [ + // [ + // [3.0, 4.0], + // [5.0, 6.0], + // [7.0, 8.0] + // ], + // [ + // [3.0, 5.0], + // [0.0, 6.0], + // [5.0, 6.0] + // ] + // ] + + // [2, 3, 2] + auto input_shape = Shape(scope, op.input(0)); + + // The Reshape with -1 flattens the reduction indices. + // [1] + auto reduction_indices = Reshape(scope, op.input(1), {-1}); + + // [2, 1, 2] + auto output_shape_kept_dims = + ReducedShapeHelper(scope, input_shape, reduction_indices); + + // [1, 3, 1] + auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims); + + // [[[105, 192]], [[0, 180]]] + auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims); + + // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]] + auto grad_tiled = Tile(scope, grad, tile_scaling); + + Scope cpu_scope = scope.WithDevice("/cpu:0"); + + // [3] + auto rank = Rank(cpu_scope, op.input(0)); + + + // Normalize any negative indices in the reduction_axes to positive values. + auto reduction_indices_pos = Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank); + + // [1] + auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32); + + // [0, 1, 2] + auto idx = Range(cpu_scope, zero, rank, one); + + // [0, 2] + auto other = SetDiff1D(cpu_scope, idx, reduced).out; + + // [1, 0, 2] + auto perm = + Concat(cpu_scope, std::initializer_list{reduced, other}, 0); + + // 3 => [3] + auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0); + + // 2 * 2 => [2] + auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0); + + // [ + // [ + // [ 3., 4.], + // [ 3., 5.] + // ], + // [ + // [ 5., 6.], + // [ 0., 6.] + // ], + // [ + // [ 7., 8.], + // [ 5., 6.] + // ] + // ] + auto permuted = Transpose(scope, op.input(0), perm); + + // [3, 2, 2] + auto permuted_shape = Shape(scope, permuted); + + // [ + // [ 3., 4., 3., 5.], + // [ 5., 6., 0., 6.], + // [ 7., 8., 5., 6.] + // ] + auto reshaped = Reshape( + scope, permuted, + Stack(scope, std::initializer_list{reduced_num, other_num})); + + // [ + // [ 1., 1., 1., 1.], + // [ 3., 4., 3., 5.], + // [ 15., 24., 0., 30.] + // ] + auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true)); + + // [ + // [ 35., 48., 0., 36.], + // [ 7., 8., 5., 6.], + // [ 1., 1., 1., 1.] + // ] + auto right = + Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true)); + + // left * right = + // [ + // [ 35., 48., 0., 36.], + // [ 21., 32., 15., 30.], + // [ 15., 24., 0., 30.] + // ] + // y = + // [ + // [ + // [ 35., 48.], + // [ 0., 36.] + // ], + // [ + // [ 21., 32.], + // [ 15., 30.] + // ], + // [ + // [ 15., 24.], + // [ 0., 30.] + // ] + // ] + auto y = Reshape(scope, Mul(scope, left, right), permuted_shape); + + // out = + // [ + // [ + // [ 35., 48.], + // [ 21., 32.], + // [ 15., 24.] + // ], + // [ + // [ 0., 36.], + // [ 15., 30.], + // [ 0., 30.] + // ] + // ] + auto out = + Mul(scope, grad_tiled, Transpose(scope, y, InvertPermutation(scope, perm))); + + grad_outputs->push_back(Reshape(scope, out, input_shape)); + + // stop propagation along reduction_indices + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Prod", ProdGrad); + // MatMulGrad helper function used to compute two MatMul operations // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 6313f41da5e5f9cf88be4c8a84408a8df77f0e25..1b4c7c2688083e74433da3dce2849b8c37443684 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -23,10 +23,31 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::Abs; +using ops::Add; +using ops::AddN; +using ops::BatchMatMul; +using ops::Const; +using ops::Div; +using ops::Greater; +using ops::MatMul; +using ops::Max; +using ops::Maximum; +using ops::Mean; +using ops::Min; +using ops::Minimum; +using ops::Mul; +using ops::Placeholder; +using ops::Pow; +using ops::Prod; +using ops::RealDiv; +using ops::SquaredDifference; +using ops::Sub; +using ops::Sum; +using ops::Where3; + // TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) As more gradients are added move common test functions // to a testutil library. @@ -83,6 +104,7 @@ class CWiseUnaryGradTest : public ::testing::Test { Output y; switch (op_type) { + using namespace ops; // NOLINT(build/namespaces) case ABS: y = Abs(scope_, x); break; @@ -843,6 +865,14 @@ TEST_F(NaryGradTest, SquaredDifference) { RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); } +TEST_F(NaryGradTest, Pow) { + TensorShape shape({3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + // fix exponent to avoid overflow + auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f})); + RunTest({x}, {shape}, {y}, {shape}); +} + TEST_F(NaryGradTest, Maximum) { TensorShape shape({3, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); @@ -865,5 +895,14 @@ TEST_F(NaryGradTest, Minimum) { RunTest(x, x_init_value, y, shape); } +TEST_F(NaryGradTest, Prod) { + TensorShape x_shape({2, 3, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = Prod(scope_, x, {1}); + // y's shape is the result of reducing x along axes 1 + TensorShape y_shape({2, 1, 2}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index f9063e836509669d81d03b1d2f0d32d1166b6eca..0cfe5f6e3c49f7c4a3cafbf48ff4e54a0ffd0d47 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -23,10 +23,22 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::BiasAdd; +using ops::Conv2D; +using ops::Elu; +using ops::L2Loss; +using ops::LogSoftmax; +using ops::LRN; +using ops::MaxPool; +using ops::MaxPoolV2; +using ops::Placeholder; +using ops::Relu; +using ops::Relu6; +using ops::Selu; +using ops::Softmax; + class NNGradTest : public ::testing::Test { protected: NNGradTest() : scope_(Scope::NewRootScope()) {} diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt deleted file mode 100644 index 4aac990e748b0a79cbc3b353b4121a582b0883b0..0000000000000000000000000000000000000000 --- a/tensorflow/cc/ops/op_gen_overrides.pbtxt +++ /dev/null @@ -1,238 +0,0 @@ -# array_ops -op { name: "BroadcastArgs" rename_to: "BroadcastDynamicShape" } -op { name: "BroadcastGradientArgs" hide: true } -op { name: "ConcatOffset" skip: true } # Maybe should just be hidden? -op { name: "Concat" skip: true } -op { name: "ConcatV2" rename_to: "Concat" } -op { name: "ExpandDims" input_rename: { from: "dim" to: "axis" } } -op { name: "ListDiff" rename_to: "SetDiff1D" } -op { name: "MirrorPadGrad" hide: true } -op { name: "Reverse" skip: true } -op { name: "ReverseV2" rename_to: "Reverse" } -op { name: "Split" input_rename: { from: "split_dim" to: "axis" } } -op { name: "SplitV" input_rename: { from: "split_dim" to: "axis" } } -op { name: "Squeeze" attr_rename: { from: "squeeze_dims" to: "axis" } } -op { name: "Pack" rename_to: "Stack" } -op { name: "Unpack" rename_to: "Unstack" } -op { name: "Select" rename_to: "Where3" input_rename: { from: "t" to: "x" } input_rename: { from: "e" to: "y" } } -op { name: "Where" input_rename: { from: "input" to: "condition" } } - - -# candidate_sampling_ops -op { name: "ThreadUnsafeUnigramCandidateSampler", skip: true } - -# control_flow_ops -# TODO(joshl): Hide Switch and Merge once we write and migrate users to -# a Cond() API. -#op { name: "Switch" hide: true } -#op { name: "Merge" hide: true } -op { name: "RefMerge" hide: true } -op { name: "Exit" hide: true } -op { name: "RefExit" hide: true } -op { name: "Enter" hide: true } -op { name: "RefEnter" hide: true } -op { name: "RefIdentity" hide: true } - -# ctc_ops - -# data_flow_ops -op { name: "FakeQueue" skip: true } -op { name: "FIFOQueue" skip: true} -op { name: "FIFOQueueV2" rename_to: "FIFOQueue" } -op { name: "PaddingFIFOQueue" skip: true } -op { name: "PaddingFIFOQueueV2" rename_to: "PaddingFIFOQueue" } -op { name: "PriorityQueue" skip: true } -op { name: "PriorityQueueV2" rename_to: "PriorityQueue" } -op { name: "QueueClose" skip: true } -op { name: "QueueCloseV2" rename_to: "QueueClose" } -op { name: "QueueDequeue" skip: true } -op { name: "QueueDequeueV2" rename_to: "QueueDequeue" } -op { name: "QueueDequeueMany" skip: true } -op { name: "QueueDequeueManyV2" rename_to: "QueueDequeueMany" } -op { name: "QueueDequeueUpTo" skip: true } -op { name: "QueueDequeueUpToV2" rename_to: "QueueDequeueUpTo" } -op { name: "QueueEnqueue" skip: true } -op { name: "QueueEnqueueV2" rename_to: "QueueEnqueue" } -op { name: "QueueEnqueueMany" skip: true } -op { name: "QueueEnqueueManyV2" rename_to: "QueueEnqueueMany" } -op { name: "QueueSize" skip: true } -op { name: "QueueSizeV2" rename_to: "QueueSize" } -op { name: "RandomShuffleQueue" skip: true } -op { name: "RandomShuffleQueueV2" rename_to: "RandomShuffleQueue" } -op { name: "ReaderNumRecordsProduced" skip: true } -op { name: "ReaderNumRecordsProducedV2" rename_to: "ReaderNumRecordsProduced" } -op { name: "ReaderNumWorkUnitsCompleted" skip: true } -op { name: "ReaderNumWorkUnitsCompletedV2" rename_to: "ReaderNumWorkUnitsCompleted" } -op { name: "ReaderRead" skip: true } -op { name: "ReaderReadUpTo" skip: true } -op { name: "ReaderReadUpToV2" rename_to: "ReaderReadUpTo" } -op { name: "ReaderReadV2" rename_to: "ReaderRead" } -op { name: "ReaderReset" skip: true } -op { name: "ReaderResetV2" rename_to: "ReaderReset" } -op { name: "ReaderRestoreState" skip: true } -op { name: "ReaderRestoreStateV2" rename_to: "ReaderRestoreState" } -op { name: "ReaderSerializeState" skip: true } -op { name: "ReaderSerializeStateV2" rename_to: "ReaderSerializeState" } -op { name: "FixedLengthRecordReader" skip: true } -op { name: "FixedLengthRecordReaderV2" rename_to: "FixedLengthRecordReader" } -op { name: "IdentityReader" skip: true } -op { name: "IdentityReaderV2" rename_to: "IdentityReader" } -op { name: "TFRecordReader" skip: true } -op { name: "TFRecordReaderV2" rename_to: "TFRecordReader" } -op { name: "TextLineReader" skip: true } -op { name: "TextLineReaderV2" rename_to: "TextLineReader" } - -# Skip hash table ops until we have better support in C++ (ops are currently -# only used in contrib) -op { name: "HashTable" skip: true } -op { name: "InitializeTable" skip: true } -op { name: "InitializeTableFromTextFile" skip: true } -op { name: "LookupTableFind" skip: true } -op { name: "LookupTableImport" skip: true } -op { name: "LookupTableInsert" skip: true } -op { name: "LookupTableSize" skip: true } -op { name: "MutableDenseHashTable" skip: true } -op { name: "MutableHashTable" skip: true } -op { name: "MutableHashTableOfTensors" skip: true } - -# Stack ops are internal to control flow gradients (not yet implemented in C++) -op { name: "Stack" skip: true } -op { name: "StackClose" skip: true } -op { name: "StackPop" skip: true } -op { name: "StackPush" skip: true } -op { name: "StackV2" skip: true } -op { name: "StackCloseV2" skip: true } -op { name: "StackPopV2" skip: true } -op { name: "StackPushV2" skip: true } - -op { name: "TensorArrayCloseV2" skip: true } -op { name: "TensorArrayCloseV3" rename_to: "TensorArrayClose" } -op { name: "TensorArrayConcatV2" skip: true } -op { name: "TensorArrayConcatV3" rename_to: "TensorArrayConcat" } -op { name: "TensorArrayGatherV2" skip: true } -op { name: "TensorArrayGatherV3" rename_to: "TensorArrayGather" } -op { name: "TensorArrayGradV2" skip: true } -op { name: "TensorArrayGradV3" rename_to: "TensorArrayGrad" } -op { name: "TensorArrayReadV2" skip: true } -op { name: "TensorArrayReadV3" rename_to: "TensorArrayRead" } -op { name: "TensorArrayScatterV2" skip: true } -op { name: "TensorArrayScatterV3" rename_to: "TensorArrayScatter" } -op { name: "TensorArraySizeV2" skip: true } -op { name: "TensorArraySizeV3" rename_to: "TensorArraySize" } -op { name: "TensorArraySplitV2" skip: true } -op { name: "TensorArraySplitV3" rename_to: "TensorArraySplit" } -op { name: "TensorArrayV2" skip: true } -op { name: "TensorArrayV3" rename_to: "TensorArray" } -op { name: "TensorArrayWriteV2" skip: true } -op { name: "TensorArrayWriteV3" rename_to: "TensorArrayWrite" } - -op { name: "WholeFileReader" skip: true } -op { name: "WholeFileReaderV2" rename_to: "WholeFileReader" } - -# functional_ops - -# image_ops -op { name: "AdjustContrastv2" rename_to: "AdjustContrast" } -op { name: "ResizeBilinearGrad" hide: true } -op { name: "ResizeBicubicGrad" hide: true } -op { name: "ResizeNearestNeighborGrad" hide: true } - -# io_ops - -# linalg_ops -op { name: "SelfAdjointEigV2" rename_to: "SelfAdjointEig" } - -# logging_ops -op { name: "AudioSummaryV2" rename_to: "AudioSummary" } - -# lookup_ops -op { name: "LookupTableFind" skip: true } -op { name: "LookupTableFindV2" rename_to: "LookupTableFind" } -op { name: "LookupTableInsert" skip: true } -op { name: "LookupTableInsertV2" rename_to: "LookupTableInsert" } -op { name: "LookupTableSize" skip: true } -op { name: "LookupTableSizeV2" rename_to: "LookupTableSize" } -op { name: "LookupTableExport" skip: true } -op { name: "LookupTableExportV2" rename_to: "LookupTableExport" } -op { name: "LookupTableImport" skip: true } -op { name: "LookupTableImportV2" rename_to: "LookupTableImport" } -op { name: "HashTable" skip: true } -op { name: "HashTableV2" rename_to: "HashTable" } -op { name: "MutableHashTable" skip: true } -op { name: "MutableHashTableV2" rename_to: "MutableHashTable" } -op { name: "MutableHashTableOfTensors" skip: true } -op { name: "MutableHashTableOfTensorsV2" rename_to: "MutableHashTableOfTensors" } -op { name: "MutableDenseHashTable" skip: true } -op { name: "MutableDenseHashTableV2" rename_to: "MutableDenseHashTable" } -op { name: "InitializeTable" skip: true } -op { name: "InitializeTableV2" rename_to: "InitializeTable" } -op { name: "InitializeTableFromTextFile" skip: true } -op { name: "InitializeTableFromTextFileV2" rename_to: "InitializeTableFromTextFile" } - -# math_ops -op { name: "All" alias: "ReduceAll" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Any" alias: "ReduceAny" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Max" alias: "ReduceMax" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Mean" alias: "ReduceMean" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Min" alias: "ReduceMin" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Mul" rename_to: "Multiply" alias: "Mul" } -op { name: "Neg" rename_to: "Negate" alias: "Neg" } -op { name: "Prod" alias: "ReduceProd" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Sub" rename_to: "Subtract" alias: "Sub" } -op { name: "Sum" alias: "ReduceSum" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "SigmoidGrad" hide: true } -op { name: "TanhGrad" hide: true } -op { name: "InvGrad" hide: true } -op { name: "ReciprocalGrad" hide: true } -op { name: "SqrtGrad" hide: true } -op { name: "RsqrtGrad" hide: true } - -# *Grad ops get hidden, only for use by the gradient code. -op { name: "SigmoidGrad" hide: true } -op { name: "TanhGrad" hide: true } -op { name: "InvGrad" hide: true } -op { name: "ReciprocalGrad" hide: true } -op { name: "SqrtGrad" hide: true } -op { name: "RsqrtGrad" hide: true } - -# nn_ops -op { name: "AvgPoolGrad" hide: true } -op { name: "LRNGrad" hide: true } -op { name: "MaxPoolGrad" hide: true } -op { name: "MaxPoolGradWithArgmax" hide: true } -op { name: "ReluGrad" hide: true } -op { name: "Relu6Grad" hide: true } -op { name: "EluGrad" hide: true } -op { name: "SeluGrad" hide: true } -op { name: "SoftplusGrad" hide: true } -op { name: "SoftsignGrad" hide: true } -op { name: "FractionalAvgPoolGrad" hide: true } -op { name: "FractionalMaxPoolGrad" hide: true } -op { name: "TopKV2" rename_to: "TopK" } -op { name: "BiasAddV1" skip: true } # Use BiasAdd instead - -# parsing_ops - -# random_ops - -op { name: "RandomStandardNormal" rename_to: "RandomNormal" } -# script_ops -# Calling Python functions from a C++ program isn't supported -op { name: "PyFunc" skip: true } -op { name: "PyFuncStateless" skip: true} - -# sdca_ops - -# state_ops - -op { name: "Variable" skip: true } -op { name: "VariableV2" rename_to: "Variable" } - -# sparse_ops - -# string_ops - -# user_ops - -# training_ops - diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc index e0251efb2a424f86bd5a4885ef22d1928e04bd3e..d1c918d464bc9684b0db6dade2fb80cb2bd6691a 100644 --- a/tensorflow/cc/ops/while_loop.cc +++ b/tensorflow/cc/ops/while_loop.cc @@ -116,7 +116,7 @@ Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, return Status::OK(); } -// Create the bdoy subgraph defined by `body`. `outputs` must be non-null and +// Create the body subgraph defined by `body`. `outputs` must be non-null and // empty. Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, const std::vector& inputs, diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..00799526fce572e7bb80199ccb8ce1cc89874031 --- /dev/null +++ b/tensorflow/cc/profiler/BUILD @@ -0,0 +1,36 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") + +tf_cuda_cc_test( + name = "profiler_test", + srcs = ["profiler_test.cc"], + deps = [ + ":profiler", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "profiler", + srcs = ["profiler.cc"], + hdrs = ["profiler.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler:protos_all_cc", + "//tensorflow/core/profiler:tfprof_options", + "//tensorflow/core/profiler/internal:tfprof_stats", + ], +) diff --git a/tensorflow/cc/profiler/profiler.cc b/tensorflow/cc/profiler/profiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e55bac73e6d32a1fa5ddcc1937744e2cf56657d --- /dev/null +++ b/tensorflow/cc/profiler/profiler.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/cc/profiler/profiler.h" + +namespace tensorflow { +namespace tfprof { + +Profiler::Profiler(const GraphDef& graph) { + std::unique_ptr graph_ptr(new GraphDef()); + *graph_ptr = graph; + stats_.reset(new TFStats(std::move(graph_ptr), nullptr, nullptr, nullptr)); +} + +void Profiler::AddStep(int64 step, const RunMetadata& run_meta) { + std::unique_ptr run_meta_ptr(new RunMetadata()); + *run_meta_ptr = run_meta; + stats_->AddRunMeta(step, std::move(run_meta_ptr)); +} + +GraphNodeProto Profiler::ProfileGraph(const Options& options) { + stats_->BuildView(kCmds[1]); + return stats_->ShowGraphNode(kCmds[1], options); +} + +GraphNodeProto Profiler::ProfileNameScope(const Options& options) { + stats_->BuildView(kCmds[0]); + return stats_->ShowGraphNode(kCmds[0], options); +} + +MultiGraphNodeProto Profiler::ProfileOperations(const Options& options) { + stats_->BuildView(kCmds[3]); + return stats_->ShowMultiGraphNode(kCmds[3], options); +} + +Status Profiler::SerializeToString(string* content) { + if (!content) { + return Status(error::Code::INVALID_ARGUMENT, + "Cannot use null string pointer for SerializeToString."); + } + stats_->SerializeToString(content); + return Status::OK(); +} + +} // namespace tfprof +} // namespace tensorflow diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..e1ce315d3c125ef9f0cb16209e199690211df440 --- /dev/null +++ b/tensorflow/cc/profiler/profiler.h @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CC_PROFILER_PROFILER_H_ +#define THIRD_PARTY_TENSORFLOW_CC_PROFILER_PROFILER_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/profiler/internal/tfprof_stats.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +/// @addtogroup core +/// @{ + +/// A `Profiler` object lets the caller profile the execution of a graph. +/// +/// Example: +/// // First build a graph and run tracing. +/// Scope root = Scope::NewRootScope(); +/// auto a = Placeholder(root, DT_INT32); +/// auto c = Add(root, a, {41}); +/// +/// ClientSession session(root); +/// std::vector outputs; +/// RunOptions run_options; +/// run_options.set_trace_level(RunOptions::FULL_TRACE); +/// RunMetadata run_meta; +/// Status s = session.Run(run_options, { {a, {1}} }, {c}, &outputs, +/// &run_meta); +/// if (!s.ok()) { ... } +/// +/// // Then create profiler to do profiling. +/// GraphDef graph; +/// root.ToGraphDef(&graph); +/// Profiler profiler(graph); +/// profiler.AddStep(0, run_meta); +/// Options opts = ... // TODO(xpan): Support option building API. +/// MultiGraphNodeProto r = profiler.ProfileOperations(opts); +/// +class Profiler { + public: + /// `graph` is the model's GraphDef. + Profiler(const GraphDef& graph); + + /// Adds tracing information `run_meta` to profiler. A `run_meta` is + /// generated by a TensorFlow session run call. `step` is the key + /// to the `run_meta`. When calling ProfileXXX methods, caller can specify + /// `step` in `options` to seletively profile the corresponding `run_meta`. + /// Multiple different `run_meta` can be keyed by the same `step` in order + /// to group them together. + void AddStep(int64 step, const RunMetadata& run_meta); + + /// Profiles the model by organizing nodes in graph structure. + /// Each node is an op and the nodes are contected by the op inputs/outputs. + GraphNodeProto ProfileGraph(const Options& options); + + /// Profiles the model by organizing nodes in name scope structure. + /// Each node is an op, and nodes are organized by the ops' name + /// scope, similar to a filesystem tree. + /// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2. + GraphNodeProto ProfileNameScope(const Options& options); + + /// Profiles the model by organizing nodes by operation types. + /// Each node is an operation type (e.g. Conv2D or MatMul), containing all + /// ops belonging to that type in the model. + MultiGraphNodeProto ProfileOperations(const Options& options); + + /// Serialize the profile content (ProfileProto) into a binary string, + /// User can write the string to file for offline analysis by + /// tfprof command-line tools or graphical user interface. + Status SerializeToString(string* content); + + private: + std::unique_ptr stats_; +}; +/// @} + +} // namespace tfprof +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CC_PROFILER_PROFILER_H_ diff --git a/tensorflow/cc/profiler/profiler_test.cc b/tensorflow/cc/profiler/profiler_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..280cd74827fc8ae80737eaf61286535fec959aa8 --- /dev/null +++ b/tensorflow/cc/profiler/profiler_test.cc @@ -0,0 +1,177 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/test.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/cc/profiler/profiler.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/default_device.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace tfprof { + +class ProfilerTest : public ::testing::Test { + protected: + ProfilerTest() {} +}; + +GraphDef CreateGraphDef() { + Scope root = Scope::NewRootScope(); + + auto a = ops::Const(root, {{3, 2}, {-1, 0}}); + + auto x = ops::Const(root.WithOpName("x"), {{1.f}, {1.f}}); + + auto y = ops::MatMul(root.WithOpName("y"), a, x); + + auto y2 = ops::Square(root, y); + + auto y2_sum = ops::Sum(root, y2, 0); + + auto y_norm = ops::Sqrt(root, y2_sum); + + auto y_div = ops::Div(root.WithOpName("y_normalized"), y, y_norm); + + GraphDef def; + TF_CHECK_OK(root.ToGraphDef(&def)); + + return def; +} + +Options Default() { + Options opts(1000, /* max_depth */ + 0, /* min_bytes */ + 0, /* min_peak_bytes */ + 0, /* min_residual_bytes */ + 0, /* min_output_bytes */ + 0, /* min_micros */ + 0, /* min_accelerator_micros */ + 0, /* min_cpu_micros */ + 0, /* min_params */ + 0, /* min_float_ops */ + 0, /* min_occurrence */ + 0, /* step */ + "name", /* order_by */ + {".*"}, /* account_type_regexes */ + {".*"}, /* start_name_regexes */ + {}, /* trim_name_regexes */ + {".*"}, {}, /* hide_name_regexes */ + false, /* account_displayed_op_only */ + {"micros"}, /* select */ + {"none"}, /* output_type */ + {}); + return opts; +} + +template +const T* ExtractNode(const T& pb, const string& name) { + if (pb.name() == name) { + return &pb; + } + for (const T& c : pb.children()) { + const T* ret = ExtractNode(c, name); + if (ret) return ret; + } + return nullptr; +} + +TEST_F(ProfilerTest, Basics) { + SessionOptions options; + options.config.set_allow_soft_placement(true); + std::unique_ptr session(NewSession(options)); + GraphDef def = CreateGraphDef(); + if (options.target.empty()) { + graph::SetDefaultDevice("/gpu:0", &def); + } + + TF_CHECK_OK(session->Create(def)); + + Tensor x(DT_FLOAT, TensorShape({2, 1})); + auto x_flat = x.flat(); + x_flat.setRandom(); + Eigen::Tensor inv_norm = + x_flat.square().sum().sqrt().inverse(); + x_flat = x_flat * inv_norm(); + + std::vector outputs; + RunOptions run_options; + run_options.set_trace_level(RunOptions::FULL_TRACE); + RunMetadata run_metadata; + outputs.clear(); + + Profiler profiler(def); + for (int i = 0; i < 2; ++i) { + TF_CHECK_OK(session->Run(run_options, {{"x", x}}, {"y:0", "y_normalized:0"}, + {}, &outputs, &run_metadata)); + profiler.AddStep(i, run_metadata); + CHECK_EQ(size_t{2}, outputs.size()); + } + + std::vector resp; + TF_CHECK_OK(session->ListDevices(&resp)); + bool has_gpu = false; + for (const auto& dev : resp) { + if (dev.device_type() == "GPU") { + has_gpu = true; + } + } + + GraphNodeProto ret = profiler.ProfileNameScope(Default()); + const GraphNodeProto* matmul = ExtractNode(ret, "y"); + EXPECT_TRUE(matmul); + EXPECT_GT(matmul->exec_micros(), 0); + if (has_gpu) { + EXPECT_GT(matmul->accelerator_exec_micros(), 0); + } else { + EXPECT_EQ(matmul->accelerator_exec_micros(), 0); + } + const GraphNodeProto* square = ExtractNode(ret, "Square"); + EXPECT_TRUE(square); + EXPECT_GT(square->exec_micros(), 0); + if (has_gpu) { + EXPECT_GT(square->accelerator_exec_micros(), 0); + } else { + EXPECT_EQ(square->accelerator_exec_micros(), 0); + } + + Options opts2 = Default(); + opts2.output_type = "timeline"; + string timeline_file = io::JoinPath(testing::TmpDir(), "timeline"); + opts2.output_options["outfile"] = timeline_file; + GraphNodeProto ret2 = profiler.ProfileGraph(opts2); + string s; + TF_CHECK_OK(ReadFileToString(Env::Default(), timeline_file + "_0", &s)); + EXPECT_TRUE(s.find("Square") != s.npos); + + MultiGraphNodeProto ret3 = profiler.ProfileOperations(Default()); + const MultiGraphNodeProto* matmul2 = ExtractNode(ret3, "MatMul"); + EXPECT_TRUE(matmul2); + EXPECT_GT(matmul2->exec_micros(), 0); + if (has_gpu) { + EXPECT_GT(matmul2->accelerator_exec_micros(), 0); + } else { + EXPECT_EQ(matmul2->accelerator_exec_micros(), 0); + } + + TF_CHECK_OK(session->Close()); +} + +} // namespace tfprof +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index f98abc8a817eca7bc129bb03a2ad31b97d957065..acef098c7d07f45d171679bff7c41e13ef0424f1 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -62,6 +62,15 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { export_dir); } +string GetTagsAsString(const std::unordered_set& tags) { + string tags_as_string = "{ "; + for (const string& tag : tags) { + tags_as_string = strings::StrCat(tags_as_string, tag, " "); + } + tags_as_string = strings::StrCat(tags_as_string, "}"); + return tags_as_string; +} + Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, const std::unordered_set& tags, MetaGraphDef* meta_graph_def_to_load) { @@ -77,14 +86,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, return Status::OK(); } } - string tags_as_string = "{ "; - for (const string& tag : tags) { - tags_as_string = strings::StrCat(tags_as_string, tag, " "); - } - tags_as_string = strings::StrCat(tags_as_string, "}"); return Status(error::Code::NOT_FOUND, "Could not find meta graph def matching supplied tags: " + - tags_as_string + + GetTagsAsString(tags) + ". To inspect available tag-sets in the SavedModel, please " "use the SavedModel CLI: `saved_model_cli`"); } @@ -233,7 +237,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, return Status(error::Code::NOT_FOUND, "SavedModel not found in export directory: " + export_dir); } - LOG(INFO) << "Loading SavedModel from: " << export_dir; + LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags) + << "; from: " << export_dir; SavedModel saved_model_proto; TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); @@ -281,7 +286,8 @@ Status LoadSavedModel(const SessionOptions& session_options, return end_microseconds - start_microseconds; }(); auto log_and_count = [&](const string& status_str) { - LOG(INFO) << "Loading SavedModel: " << status_str << ". Took " + LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags) + << "; Status: " << status_str << ". Took " << load_latency_microsecs << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; diff --git a/tensorflow/cc/saved_model/python/BUILD b/tensorflow/cc/saved_model/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f5fbc75edcba9d5ae9ef7432de224df766bcab9e --- /dev/null +++ b/tensorflow/cc/saved_model/python/BUILD @@ -0,0 +1,30 @@ +# Description: +# CLIF wrappers for TensorFlow SavedModels. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = ["//visibility:public"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +load("//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc") + +tf_py_clif_cc( + name = "loader", + srcs = ["loader.clif"], + deps = [ + "//tensorflow/cc/saved_model:loader", + ], +) diff --git a/tensorflow/cc/saved_model/python/loader.clif b/tensorflow/cc/saved_model/python/loader.clif new file mode 100644 index 0000000000000000000000000000000000000000..b102757d2eeb46ee713d8ed0d0c3d66b58740ee0 --- /dev/null +++ b/tensorflow/cc/saved_model/python/loader.clif @@ -0,0 +1,4 @@ +from "third_party/tensorflow/cc/saved_model/loader.h": + namespace `tensorflow`: + class SavedModelBundle: + def __init__(self) diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0a7c37383f96ca65bf5ae05cf0827c01dc4d799b --- /dev/null +++ b/tensorflow/cc/tools/BUILD @@ -0,0 +1,58 @@ +# Description: +# TensorFlow cc tools. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "freeze_saved_model", + srcs = ["freeze_saved_model.cc"], + hdrs = ["freeze_saved_model.h"], + deps = [ + "//tensorflow/cc/saved_model:loader", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +tf_cc_test( + name = "freeze_saved_model_test", + srcs = ["freeze_saved_model_test.cc"], + deps = [ + ":freeze_saved_model", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +# ----------------------------------------------------------------------------- +# Google-internal targets. + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..ddf372cdef21e1b3892c9a03714478d5a5785517 --- /dev/null +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -0,0 +1,194 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/tools/freeze_saved_model.h" + +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { + +namespace { + +// Gets tensor names from tensor_info and inserts them into the set of tensor +// names. +void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info, + std::unordered_set* tensor_names) { + if (tensor_info.has_coo_sparse()) { + // If the tensor is sparse we have to add all three tensors of the sparse + // representations. + const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse(); + tensor_names->insert(coo_sparse.values_tensor_name()); + tensor_names->insert(coo_sparse.indices_tensor_name()); + tensor_names->insert(coo_sparse.dense_shape_tensor_name()); + } else { + tensor_names->insert(tensor_info.name()); + } +} + +// Gets the union of all inputs and outputs of all SignatureDefs in the bundle +void GetSignatureDefsInputsAndOutputs( + const SavedModelBundle& saved_model_bundle, + std::unordered_set* inputs, std::unordered_set* outputs) { + for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) { + const SignatureDef& signature_def = sigdef_elem.second; + for (auto& input_elem : signature_def.inputs()) { + GetTensorNamesFromTensorInfo(input_elem.second, inputs); + } + for (auto& output_elem : signature_def.outputs()) { + GetTensorNamesFromTensorInfo(output_elem.second, outputs); + } + } +} + +// Gets a map from string node name to NodeDef. +void GetNodeNameToNodeDefMap( + GraphDef* graph_def, + std::unordered_map* name_to_node_map) { + for (size_t i = 0; i < graph_def->node_size(); i++) { + NodeDef* node = graph_def->mutable_node(i); + (*name_to_node_map)[node->name()] = node; + } +} + +// Gets the set of node names needed by `outputs` and the corresponding set of +// variable nodes to convert. +void GetReachableNodesAndVariables( + GraphDef* graph_def, const std::unordered_set& outputs, + std::unordered_set* reachable_node_names, + std::unordered_set* variable_node_names) { + // TODO(suharshs): Add support for ResourceVariables. + static const std::unordered_set* kVariableTypes = + new std::unordered_set({"Variable", "VariableV2"}); + // name_to_node_map is needed to get the inputs from the NodeDef corresponding + // the a string node name. These inputs are used when doing our backwards + // traversal. + std::unordered_map name_to_node_map; + GetNodeNameToNodeDefMap(graph_def, &name_to_node_map); + std::queue nodes_to_visit; + for (const string& tensor_name : outputs) { + // We need to strip off the tensor part to get the node name. + std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); + nodes_to_visit.push(tensor_name_parts[0]); + } + // We do a traversal backwards from the outputs specified in the MetaGraphDef. + while (!nodes_to_visit.empty()) { + const string node_name = nodes_to_visit.front(); + nodes_to_visit.pop(); + if (reachable_node_names->find(node_name) != reachable_node_names->end()) { + continue; + } + reachable_node_names->insert(node_name); + NodeDef* node = name_to_node_map[node_name]; + if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { + variable_node_names->insert(node->name()); + } + for (const string& input : node->input()) { + nodes_to_visit.push(input); + } + } +} + +// Gets a map from variable name to variable value. +Status GetVariableNameToTensorMap( + Session* session, std::unordered_set variable_names_set, + std::unordered_map* variable_name_to_value_map) { + if (variable_names_set.empty()) { + return Status::OK(); + } + std::vector variable_names; + std::vector tensor_names; + for (const string& node_name : variable_names_set) { + variable_names.push_back(node_name); + // We need to run tensors, so append ":0". + tensor_names.push_back(node_name + ":0"); + } + std::vector outputs; + TF_RETURN_IF_ERROR( + session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs)); + for (size_t i = 0; i < variable_names.size(); i++) { + (*variable_name_to_value_map)[variable_names[i]] = outputs[i]; + } + return Status::OK(); +} + +// Converts a Variable NodeDef into a Constant NodeDef. +void ConvertVariableToConstant(const NodeDef& variable_node, + const Tensor& variable_value, + NodeDef* const_node) { + const_node->set_name(variable_node.name()); + const_node->set_op("Const"); + (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype"); + variable_value.AsProtoTensorContent( + (*const_node->mutable_attr())["value"].mutable_tensor()); +} + +// Freezes the subgraph of all nodes needed by `outputs`. +Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, + const std::unordered_set& outputs, + GraphDef* frozen_graph_def) { + GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def(); + // Copy versions and library as-is from original graph. + *frozen_graph_def->mutable_versions() = graph_def.versions(); + *frozen_graph_def->mutable_library() = graph_def.library(); + // If the graph is empty there is nothing left to do. + if (graph_def.node_size() == 0) { + return Status::OK(); + } + std::unordered_set reachable_node_names; + std::unordered_set variable_node_names; + GetReachableNodesAndVariables(&graph_def, outputs, &reachable_node_names, + &variable_node_names); + std::unordered_map variable_to_value_map; + TF_RETURN_IF_ERROR( + GetVariableNameToTensorMap(saved_model_bundle.session.get(), + variable_node_names, &variable_to_value_map)); + // We copy the nodes in the same order they were in the original graph_def. + for (const NodeDef& node : graph_def.node()) { + if (reachable_node_names.find(node.name()) == reachable_node_names.end()) { + continue; + } + if (variable_node_names.find(node.name()) != variable_node_names.end()) { + ConvertVariableToConstant(node, variable_to_value_map[node.name()], + frozen_graph_def->add_node()); + } else { + // If the node isn't a variable, just copy the node as-is. + *frozen_graph_def->add_node() = node; + } + } + return Status::OK(); +} + +} // namespace + +Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, + GraphDef* frozen_graph_def, + std::unordered_set* inputs, + std::unordered_set* outputs) { + GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs); + TF_RETURN_IF_ERROR( + FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def)); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/tools/freeze_saved_model.h b/tensorflow/cc/tools/freeze_saved_model.h new file mode 100644 index 0000000000000000000000000000000000000000..bd5e0516c8999dc235747ccec75a57542b0f9bf7 --- /dev/null +++ b/tensorflow/cc/tools/freeze_saved_model.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ +#define THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ + +#include + +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Returns a frozen GraphDef, input tensors, and output tensors from the loaded +// SavedModelBundle. +// `inputs` and `outputs` consist of the union of all inputs and outputs in the +// SignatureDefs in the SavedModelBundle. +// FreezeSavedModel sets `frozen_graph_def` to a GraphDef of all nodes needed by +// `outputs`. All variables in the supplied SavedModelBundle are converted to +// constants, set to the value of the variables, by running the restored Session +// in the SavedModelBundle. +// WARNING: Only the variable checkpoints will be reflected in the frozen +// graph_def. All saved_model assets will be ignored. +Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, + GraphDef* frozen_graph_def, + std::unordered_set* inputs, + std::unordered_set* outputs); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..57244a4f0adeb9775e35445f77205f3d221ee05b --- /dev/null +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -0,0 +1,307 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/tools/freeze_saved_model.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { + +class FreezeTest : public ::testing::Test { + protected: + void GraphDefEqual(const GraphDef& actual, const GraphDef& expected) { + EXPECT_EQ(actual.ShortDebugString(), expected.ShortDebugString()); + } + + // Builds a SignatureDef with the provided `inputs` and `outputs`. + SignatureDef BuildSignatureDef(const std::unordered_set& inputs, + const std::unordered_set& outputs) { + SignatureDef signature_def; + for (const string& input : inputs) { + (*signature_def.mutable_inputs())[input].set_name(input); + } + for (const string& output : outputs) { + (*signature_def.mutable_outputs())[output].set_name(output); + } + return signature_def; + } + + // Adds `signature_def` to `saved_model_bundle` under `key`. + void AddSignatureDefToSavedModelBundle(const SignatureDef& signature_def, + const string& key, + SavedModelBundle* saved_model_bundle) { + MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def; + (*meta_graph_def->mutable_signature_def())[key] = signature_def; + } + + // Adds an initialized session to `saved_model_bundle` using `graph_def` and + // initializing with `init_node`. + Status InitializeSavedModelBundleSession( + const GraphDef& graph_def, const string& init_node, + SavedModelBundle* saved_model_bundle) { + SessionOptions session_options; + saved_model_bundle->session.reset(NewSession(session_options)); + TF_RETURN_IF_ERROR(saved_model_bundle->session->Create(graph_def)); + if (!init_node.empty()) { + std::vector outputs; + return saved_model_bundle->session->Run( + /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs); + } + return Status::OK(); + } + + // Adds `graph_def` to `saved_model_bundle` and intializes a session with + // `init_node`. + Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def, + const string& init_node, + SavedModelBundle* saved_model_bundle) { + MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def; + *meta_graph_def->mutable_graph_def() = graph_def; + return InitializeSavedModelBundleSession(graph_def, init_node, + saved_model_bundle); + } + + // Adds `graph_def` and `outputs` as the GraphDef and SignatureDef in + // `saved_model_bundle` and initializes a session with `init_node`. + Status AddGraphDefWithOutputsToSavedModelBundle( + const GraphDef& graph_def, const std::unordered_set& outputs, + const string& init_node, SavedModelBundle* saved_model_bundle) { + SignatureDef signature_def = + BuildSignatureDef(std::unordered_set(), outputs); + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + saved_model_bundle); + return AddGraphDefToSavedModelBundle(graph_def, init_node, + saved_model_bundle); + } + + // Runs and compares the outputs of `tensor_name` on both the + // `unfrozen_session` and the `frozen_graph_def. + void RunAndCompareFrozenAndUnfrozenGraphs(Session* unfrozen_session, + const GraphDef& frozen_graph_def, + const string& tensor_name) { + std::vector unfrozen_outputs; + TF_ASSERT_OK(unfrozen_session->Run(/* inputs */ {}, {tensor_name}, + /* targets */ {}, &unfrozen_outputs)); + + SessionOptions session_options; + std::unique_ptr frozen_session(NewSession(session_options)); + TF_ASSERT_OK(frozen_session->Create(frozen_graph_def)); + std::vector frozen_outputs; + TF_ASSERT_OK(frozen_session->Run(/* inputs */ {}, {tensor_name}, + /* targets */ {}, &frozen_outputs)); + + test::ExpectTensorEqual(unfrozen_outputs[0], frozen_outputs[0]); + } +}; + +TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) { + // Test that inputs and outputs get correctly populated for a single + // SignatureDef. + SavedModelBundle saved_model_bundle; + std::unordered_set expected_inputs = {"input0:0", "input1:0"}; + std::unordered_set expected_outputs = {"output0:0", "output1:0"}; + SignatureDef signature_def = + BuildSignatureDef(expected_inputs, expected_outputs); + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + +TEST_F(FreezeTest, InputsAndOutputsMultipleSignatureDefs) { + // Test that inputs and outputs get correctly merged and populated when + // multiple SignatureDefs are provided. + SavedModelBundle saved_model_bundle; + SignatureDef signature_def_0 = BuildSignatureDef({"input0:0"}, {"output0:0"}); + SignatureDef signature_def_1 = BuildSignatureDef({"input1:0"}, {"output1:0"}); + AddSignatureDefToSavedModelBundle(signature_def_0, "signature_def_0", + &saved_model_bundle); + AddSignatureDefToSavedModelBundle(signature_def_1, "signature_def_1", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + std::unordered_set expected_inputs = {"input0:0", "input1:0"}; + std::unordered_set expected_outputs = {"output0:0", "output1:0"}; + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + +TEST_F(FreezeTest, GraphDefVersionsAndLibrary) { + // Test that GraphDef versions and library are copied correctly into the + // frozen graph. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + graph_def.mutable_versions()->set_producer(1234); + graph_def.mutable_versions()->set_min_consumer(1234); + *graph_def.mutable_library()->add_function() = test::function::NonZero(); + TF_ASSERT_OK( + AddGraphDefToSavedModelBundle(graph_def, "", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithNoVariables) { + // Test freezing a graph with no variables. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) { + // Test freezing a graph with variables that are not needed by the outputs in + // the SignatureDef. The resulting graph shouldn't be frozen, but + // non-dependent nodes should be pruned. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDef expected_graph_def; + Scope expected_scope = Scope::NewRootScope(); + Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {}); + Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {}); + Output expected_c = + ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b); + TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def)); + + GraphDefEqual(frozen_graph_def, expected_graph_def); + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); +} + +TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) { + // Test freezing a graph with variables that are needed by outputs in the + // SignatureDef. The variables should be frozen. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output c = ops::Mul(scope.WithOpName("c"), a, var); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + // There should be 3 nodes in the resulting graph_def, and none should be + // variables. + EXPECT_EQ(frozen_graph_def.node_size(), 3); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); +} + +TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) { + // Test freezing a graph with some variables that are needed and not needed by + // the outputs in the SignatureDef. The resulting graph should only freeze + // dependent variables. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output c = ops::Mul(scope.WithOpName("c"), a, var); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + Output var_1 = + ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT); + Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + // There should be 3 nodes in the resulting graph_def, and none should be + // variables. + EXPECT_EQ(frozen_graph_def.node_size(), 3); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index a9a6ea84319a18a8fbce648391bf5918ff6d9a08..0540260efd83e18258ec6e93c514d14e328791b1 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -24,7 +24,6 @@ tf_cc_test( srcs = ["runtime_test.cc"], deps = [ ":runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -53,6 +52,7 @@ cc_library( "flags.h", ], deps = [ + ":embedded_protocol_buffers", ":runtime", # needed by codegen to print aligned_buffer_bytes "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:common", @@ -69,9 +69,7 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -81,13 +79,18 @@ cc_library( tf_cc_test( name = "codegen_test", srcs = ["codegen_test.cc"], - data = ["codegen_test_h.golden"], + data = [ + "codegen_test_h.golden", + "codegen_test_o.golden", + ], deps = [ ":tfcompile_lib", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@llvm//:support", # fixdeps: keep + "@llvm//:x86_code_gen", # fixdeps: keep ], ) @@ -111,6 +114,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -190,6 +194,23 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "embedded_protocol_buffers", + srcs = ["embedded_protocol_buffers.cc"], + hdrs = ["embedded_protocol_buffers.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + "@llvm//:execution_engine", + "@llvm//:support", + "@llvm//:target", + ], +) + tf_cc_test( name = "benchmark_test", srcs = ["benchmark_test.cc"], diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index ae22f7edc423247b34895411d19d7a3c21f86d4f..2cae85e8965216eaaee4d3032015d0016258a5c1 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/aot/runtime.h" #include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -101,21 +102,8 @@ Status ComputeArgSizes(const CompileResult& compile_result, std::vector* arg_sizes) { const xla::ProgramShape& ps = compile_result.program_shape; for (int i = 0; i < ps.parameters_size(); ++i) { - if (i == ps.parameters_size() - 1 && compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. - const xla::PrimitiveType type = ps.parameters(i).element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(ps)); - } - arg_sizes->push_back(-1); - } else { - arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( - ps.parameters(i), compile_result.pointer_size)); - } + arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( + ps.parameters(i), compile_result.pointer_size)); } return Status::OK(); } @@ -165,11 +153,6 @@ string RewriteWithName(const string& name, string code, Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); - if (compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and is set in the class constructor. - num_args--; - } if (config.feed_size() != num_args) { return errors::InvalidArgument("mismatch between feed_size(", config.feed_size(), ") and num_args(", @@ -281,49 +264,6 @@ string GenNameToIndexCode(const T& entries, bool generate) { return code; } -// Converts the given `str` into a comma-separated list of per-character values. -string StringToCharList(const string& str) { - string list; - for (const char c : str) { - if (!list.empty()) { - list += ","; - } - list += strings::StrCat(static_cast(c)); - } - return list; -} - -string GenProgramShapeCode(xla::ProgramShape program_shape, bool generate) { - // No need for any static magic if we're not supposed to generate the data. - if (!generate) { - return "{\n return nullptr;\n }"; - } - // The parameter names are currently meaningless, and redundant with the rest - // of our metadata, so clear them out to avoid confusion and save space. - program_shape.clear_parameter_names(); - const string proto_str = program_shape.SerializeAsString(); - // Embed the program shape as a serialized protobuf in the header file. - // - // TODO(toddw): This strategy will likely fail for larger protobufs, depending - // on the C++ compiler that is used. Figure out another solution if necessary. - string code = R"({ - static const xla::ProgramShape* kShape = []() { - static const char kProto[] = {{{PROTO_LIST}}}; - static constexpr int kProtoSize = {{PROTO_SIZE}}; - xla::ProgramShape* shape = new xla::ProgramShape; - shape->ParseFromArray(kProto, kProtoSize); - return shape; - }(); - return kShape; - })"; - str_util::ReplaceAllPairs( - &code, { - {"{{PROTO_LIST}}", StringToCharList(proto_str)}, - {"{{PROTO_SIZE}}", strings::StrCat(proto_str.size())}, - }); - return code; -} - Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { for (const tf2xla::Feed& feed : config.feed()) { if (!feed.name().empty()) { @@ -340,8 +280,9 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { } // namespace -Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, - const CompileResult& compile_result, string* header) { +Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, + const CompileResult& compile_result, + const MetadataResult& metadata_result, string* header) { TF_RETURN_IF_ERROR(ValidateConfig(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64 result_index = compile_result.aot->result_buffer_index(); @@ -391,8 +332,6 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")" : ""; - const string program_shape_code = - GenProgramShapeCode(ps, opts.gen_program_shape); // Use a poor-man's text templating mechanism; first populate the full header // with placeholder tokens, and then rewrite the tokens with real values. @@ -418,7 +357,9 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); + +{{DECLS_FROM_OBJ_FILE}} {{NS_START}} // {{CLASS}} represents a computation previously specified in a @@ -474,7 +415,6 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = {{HAS_CONTEXT_ARG}}; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -483,7 +423,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} {{CLASS}}(const {{CLASS}}&) = delete; @@ -496,8 +436,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -543,7 +483,10 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static const char** StaticResultNames() {{RESULT_NAMES_CODE}} // Shape of the args and results. - static const xla::ProgramShape* StaticProgramShape() {{PROGRAM_SHAPE_CODE}} + static const xla::ProgramShape* StaticProgramShape() { + static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; + return kShape; + } }; {{NS_END}} @@ -560,26 +503,68 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{CLASS}}", opts.class_name}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{HAS_CONTEXT_ARG}}", - compile_result.has_context_arg ? "true" : "false"}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, - {"{{PROGRAM_SHAPE_CODE}}", program_shape_code}, {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}, - }; + {"{{DECLS_FROM_OBJ_FILE}}", + str_util::Join(metadata_result.header_variable_decls, "\n")}, + {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", + metadata_result.program_shape_access_shim}}; str_util::ReplaceAllPairs(header, rewrites); return Status::OK(); } +static string CreateUniqueIdentifierForProgramShape(const CodegenOpts& opts) { + string result = "__tfcompile"; + for (const string& n : opts.namespaces) { + strings::StrAppend(&result, "_", n); + } + + strings::StrAppend(&result, "_", opts.class_name, "_ProgramShape"); + return result; +} + +Status GenerateMetadata(const CodegenOpts& opts, + const CompileResult& compile_result, + MetadataResult* metadata_result) { + std::unique_ptr program_shape; + + if (opts.gen_program_shape) { + program_shape = + tensorflow::MakeUnique(compile_result.program_shape); + // The parameter names are currently meaningless, and redundant with the + // rest of our metadata, so clear them out to avoid confusion and save + // space. + program_shape->clear_parameter_names(); + } + + // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives + // a shim that evaluates to nullptr, which is what we want. + + TF_ASSIGN_OR_RETURN( + EmbeddedProtocolBuffer embedded_program_shape, + CreateEmbeddedProtocolBuffer(opts.target_triple, + CreateUniqueIdentifierForProgramShape(opts), + "xla::ProgramShape", program_shape.get())); + + metadata_result->program_shape_access_shim = + std::move(embedded_program_shape.cpp_shim_expression); + metadata_result->header_variable_decls.emplace_back( + std::move(embedded_program_shape.cpp_variable_decl)); + metadata_result->object_file_data = + std::move(embedded_program_shape.object_file_data); + return Status::OK(); +} + Status ParseCppClass(const string& cpp_class, string* class_name, std::vector* namespaces) { class_name->clear(); diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 76dd0cc3cf9470a1beb2a4725724f640aecfec7f..3430b1f96cf4d3c035b76c77ccf124c5d164751e 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -26,11 +26,15 @@ limitations under the License. namespace tensorflow { namespace tfcompile { -// HeaderOpts specifies options for header-file generation. -struct HeaderOpts { +// CodegenOpts specifies code generation options for the generated header file +// and the generated metadata object file. +struct CodegenOpts { // The name of the generated C++ class, wrapping the generated function. string class_name; + // Target triple for the architecture we're targeting. + string target_triple; + // Namespaces specifies a list of C++ namespaces to add to the generated // header. If empty, all symbols will be in the global namespace. std::vector namespaces; @@ -42,11 +46,36 @@ struct HeaderOpts { bool gen_program_shape = false; }; +// Describes a generated metadata object file. +struct MetadataResult { + // These are top level "extern C" declarations that are expected to be visible + // wherever program_shape_access_shim is emitted. + std::vector header_variable_decls; + + // program_shape_access_shim is a C++ expression that constructs the + // xla::ProgramShape instance for the CompileResult passed to + // GenerateMetadata. + string program_shape_access_shim; + + // The contents of the object (".o") file. + string object_file_data; +}; + +// Generates a metadata object file according to `opts` and `compile_result`. +// The generated object file is returned via `metadata_result`. +Status GenerateMetadata(const CodegenOpts& opts, + const CompileResult& compile_result, + MetadataResult* metadata_result); + // GenerateHeader uses the meta-information from compile_result to generate a // C++ header giving access to the function in the generated object file. The // header includes API usage documentation. -Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, - const CompileResult& compile_result, string* header); +// +// metadata_result is an instance of MetadataResult obtained by a previous +// invocation to GenerateMetadata. +Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, + const CompileResult& compile_result, + const MetadataResult& metadata_result, string* header); // ParseCppClass parses `cpp_class` into its `class_name` and `namespaces` // components. The syntax is [[::],...]. This diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 0f6114666fcc89c631434527d2ae8c92c039ffea..972b7d51ecb3798e61757ac55e973075a23b433a 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -123,9 +124,39 @@ TEST_F(ParseCppClassTest, ParseFail) { ExpectFail("good::0bad"); } -TEST(GenerateHeader, Golden) { - HeaderOpts opts; +static void CompareWithGoldenFile( + const string& tensorflow_relative_golden_file_name, + const string& expected_contents) { + // To update the golden file, flip update_golden to true and run the + // following: + // bazel test --test_strategy=local \ + // third_party/tensorflow/compiler/aot:codegen_test + const bool update_golden = false; + const string golden_file_name = io::JoinPath( + testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name); + + if (update_golden) { + TF_EXPECT_OK( + WriteStringToFile(Env::Default(), golden_file_name, expected_contents)); + } + + string golden_file_contents; + TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name, + &golden_file_contents)); + EXPECT_EQ(golden_file_contents, expected_contents); +} + +TEST(CodegenTest, Golden) { + // Normally CpuCompiler::CpuCompiler does this, but in this test we've + // bypassed the Cpu compiler so we have to do this manually. + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + LLVMInitializeX86Target(); + LLVMInitializeX86TargetMC(); + + CodegenOpts opts; opts.class_name = "MyClass"; + opts.target_triple = "x86_64-pc-linux"; opts.namespaces = {"foo", "bar"}; opts.gen_name_to_index = true; opts.gen_program_shape = true; @@ -145,32 +176,27 @@ TEST(GenerateHeader, Golden) { { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), - xla::ShapeUtil::MakeOpaqueShape(), }, xla::ShapeUtil::MakeTupleShape( {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); - compile_result.has_context_arg = true; compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; + + MetadataResult metadata_result; + TF_ASSERT_OK(GenerateMetadata(opts, compile_result, &metadata_result)); + + // The other fields in metadata_result are tested as part of the generated + // header test. + + CompareWithGoldenFile("compiler/aot/codegen_test_o.golden", + metadata_result.object_file_data); + string header; - TF_EXPECT_OK(GenerateHeader(opts, config, compile_result, &header)); + TF_ASSERT_OK( + GenerateHeader(opts, config, compile_result, metadata_result, &header)); - // Compare against the golden file. - const string golden_name = io::JoinPath(testing::TensorFlowSrcRoot(), - "compiler/aot/codegen_test_h.golden"); - // To update the golden file, flip update_golden to true and run the - // following: - // bazel test --test_strategy=local \ - // third_party/tensorflow/compiler/aot:codegen_test - const bool update_golden = false; - if (update_golden) { - TF_EXPECT_OK(WriteStringToFile(Env::Default(), golden_name, header)); - } - string golden_data; - TF_EXPECT_OK(ReadFileToString(Env::Default(), golden_name, &golden_data)); - EXPECT_EQ(header, golden_data); + CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header); } - } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 65f342ce27ef09092f252f791973f245a8cdd6f3..ac3b5873318873b5fdf41bd556a0b2abddc2b30b 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -19,7 +19,9 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); + +extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; namespace foo { namespace bar { @@ -48,7 +50,7 @@ namespace bar { // is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> (u32[5,6]) +// ((unknown): f32[1,2], (unknown): s64[3,4]) -> (u32[5,6]) // // Memory stats: // arg bytes total: 104 @@ -58,11 +60,11 @@ namespace bar { class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. - static constexpr size_t kNumArgs = 3; + static constexpr size_t kNumArgs = 2; // Byte size of each argument buffer. There are kNumArgs entries. static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96, -1}; + static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96}; return kArgSizes; } @@ -77,7 +79,6 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = true; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -86,7 +87,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} MyClass(const MyClass&) = delete; @@ -99,8 +100,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -236,12 +237,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // Shape of the args and results. static const xla::ProgramShape* StaticProgramShape() { static const xla::ProgramShape* kShape = []() { - static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,10,2,16,14,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; - static constexpr int kProtoSize = 50; - xla::ProgramShape* shape = new xla::ProgramShape; - shape->ParseFromArray(kProto, kProtoSize); - return shape; - }(); + xla::ProgramShape* proto = new xla::ProgramShape; + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52); + return proto; + }(); return kShape; } }; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden new file mode 100644 index 0000000000000000000000000000000000000000..eb001c5d45bdfefc76629d7303d89f5480432235 Binary files /dev/null and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 2b8cc6024cb85e4f6269313927ff66d1d9a1cf79..c87f2b75dfa18ad5c3eda4bd6fcbcb3083ef73fd 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -94,9 +94,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); xla::Computation computation; - TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client, - &computation, - &compile_result->has_context_arg)); + TF_RETURN_IF_ERROR( + ConvertGraphDefToXla(graph_def, config, client, &computation)); if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 965c2960816b3acc8d2209e6824d88647de0ce14..e03c5b1aa77c1262ed903aae3072ef65f34d80a2 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -34,7 +34,6 @@ struct CompileResult { // Contains object file and meta-info. std::unique_ptr aot; xla::ProgramShape program_shape; // Static shape of args and results. - bool has_context_arg = false; // Is last arg XlaLocalRuntimeContext? string entry_point; // Name of generated function. int pointer_size = 0; // Size of a pointer in bytes. }; diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc new file mode 100644 index 0000000000000000000000000000000000000000..6489929a576d6469c4ff1358ca5ee9d27fb578bb --- /dev/null +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -0,0 +1,158 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/embedded_protocol_buffers.h" + +#include +#include + +#include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "tensorflow/compiler/tf2xla/str_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/util.h" + +namespace tensorflow { +namespace tfcompile { + +using xla::llvm_ir::AsStringRef; + +static std::unique_ptr CreateModuleWithEmbeddedProtocolBuffer( + llvm::LLVMContext* llvm_context, llvm::TargetMachine* target_machine, + const ::tensorflow::protobuf::MessageLite& proto, + StringPiece unique_identifier, string* protobuf_array_symbol_name, + int64* protobuf_array_size) { + string protobuf_array_contents = proto.SerializeAsString(); + *protobuf_array_symbol_name = + strings::StrCat(unique_identifier, "_protobuf_array_contents"); + *protobuf_array_size = protobuf_array_contents.size(); + + std::unique_ptr module = + MakeUnique("embedded_data_module", *llvm_context); + + llvm::Constant* protobuf_array_initializer = + llvm::ConstantDataArray::getString(*llvm_context, + AsStringRef(protobuf_array_contents), + /*AddNull=*/false); + new llvm::GlobalVariable( + *module, protobuf_array_initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, + protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); + + return module; +} + +static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, + StringPiece protobuf_array_symbol_name, + int64 protobuf_array_size) { + string code = + "[]() {\n" + " {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n" + " proto->ParseFromArray(&{{ARRAY_SYMBOL}}[0], {{ARRAY_SIZE}});\n" + " return proto;\n" + " }()"; + + str_util::ReplaceAllPairs( + &code, + { + {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)}, + {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)}, + {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)}, + }); + return code; +} + +static StatusOr CodegenModule(llvm::TargetMachine* target_machine, + std::unique_ptr module) { + llvm::SmallVector stream_buffer; + llvm::raw_svector_ostream ostream(stream_buffer); + llvm::legacy::PassManager codegen_passes; + + if (target_machine->addPassesToEmitFile( + codegen_passes, ostream, llvm::TargetMachine::CGFT_ObjectFile)) { + return xla::InternalError( + "Could not create pass pipeline to generate object file"); + } + + codegen_passes.run(*module); + + return string(stream_buffer.begin(), stream_buffer.end()); +} + +static StatusOr> +GetTargetMachineFromTriple(StringPiece target_triple) { + std::string error; + std::string normalized_triple = + llvm::Triple::normalize(AsStringRef(target_triple)); + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(normalized_triple, error); + if (target == nullptr) { + return xla::InternalError("TargetRegistry::lookupTarget failed: %s", + error.c_str()); + } + + return WrapUnique(target->createTargetMachine( + normalized_triple, /*CPU=*/"", + /*Features=*/"", llvm::TargetOptions(), llvm::None)); +} + +StatusOr CreateEmbeddedProtocolBuffer( + StringPiece target_triple, StringPiece symbol_prefix, + StringPiece qualified_cpp_protobuf_name, + const ::tensorflow::protobuf::MessageLite* proto) { + TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, + GetTargetMachineFromTriple(target_triple)); + + llvm::LLVMContext llvm_context; + string object_file, cpp_shim, cpp_variable_decl; + + if (proto) { + string protobuf_array_symbol_name; + int64 protobuf_array_size; + + std::unique_ptr module_with_serialized_proto = + CreateModuleWithEmbeddedProtocolBuffer( + &llvm_context, target_machine.get(), *proto, symbol_prefix, + &protobuf_array_symbol_name, &protobuf_array_size); + TF_ASSIGN_OR_RETURN(object_file, + CodegenModule(target_machine.get(), + std::move(module_with_serialized_proto))); + cpp_shim = CreateCPPShimExpression(qualified_cpp_protobuf_name, + protobuf_array_symbol_name, + protobuf_array_size); + + cpp_variable_decl = strings::StrCat("extern \"C\" char ", + protobuf_array_symbol_name, "[];"); + } else { + TF_ASSIGN_OR_RETURN( + object_file, + CodegenModule(target_machine.get(), + MakeUnique("empty_module", llvm_context))); + cpp_shim = "nullptr"; + } + + return {{cpp_shim, cpp_variable_decl, object_file}}; +} + +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h new file mode 100644 index 0000000000000000000000000000000000000000..8436e0ff67f352a24e3d16b46f16c1ad2f3a5957 --- /dev/null +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines utilities to help "embed" protocol buffers into object +// (".o") files. These C++ binaries and shared objects can link in these .o to +// get access to said protocol buffers at runtime. + +#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ +#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace tfcompile { +using xla::StatusOr; + +// Represents a protocol buffer embedded into an object file and describes a way +// to access it at runtime. +struct EmbeddedProtocolBuffer { + // cpp_shim_expression is a C++ expression that creates an instance of said + // protocol buffer when executed. + string cpp_shim_expression; + + // cpp_variable_decl is an "extern C" array declaration that is used in + // cpp_shim_expression. It must be visible wherever cpp_shim_expression is + // emitted. + string cpp_variable_decl; + + // The contents of the object (".o") file the protocol buffer is embbed in. + // This needs to be linked in to any program that wants to execute + // cpp_variable_decl . + string object_file_data; +}; + +// Creates an object file that contains `proto`. +// +// `proto` is allowed to be nullptr, in which case the generated C++ shim +// expression is just `nullptr`, and the generated object file does not define +// any symbols. +// +// `target_triple` is the target triple for the target architecture for the +// generated object file. +// +// `symbol_prefix` is prefix that is guaranteed to be unique across the binary +// or DSO the generated object file will be linked into. +// +// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ +// namespace qualified) protocol buffer name. This needs is only used in +// EmbeddedProtocolBuffer::cpp_shim_expression so relatively qualified +// names are fine as long as they're valid wherever cpp_shim_expression +// is emitted. +StatusOr CreateEmbeddedProtocolBuffer( + StringPiece target_triple, StringPiece symbol_prefix, + StringPiece qualified_cpp_protobuf_name, + const ::tensorflow::protobuf::MessageLite* proto); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc index 7c2f27e550d44c2487f91acf1029c962ac3f5d01..8c95cb8f90ee031fdbb97fabd9d86f848b42e4c5 100644 --- a/tensorflow/compiler/aot/flags.cc +++ b/tensorflow/compiler/aot/flags.cc @@ -59,8 +59,13 @@ void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { "namespaces may precede the class name, separated by double-colons. " "The class will be generated in the given namespace(s), or if no " "namespaces are given, within the global namespace."}, - {"out_object", &flags->out_object, "Output object file name."}, + {"out_function_object", &flags->out_function_object, + "Output object file containing the generated function for the " + "TensorFlow model."}, {"out_header", &flags->out_header, "Output header file name."}, + {"out_metadata_object", &flags->out_metadata_object, + "Output object file name containing optional metadata for the generated " + "function."}, {"out_session_module", &flags->out_session_module, "Output session module proto."}, {"gen_name_to_index", &flags->gen_name_to_index, diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index 3519659e3af7cd345f30080a07ce91fb858623fb..d266fbead61f7eb43863d1c67c0f86926ae9452d 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -34,7 +34,8 @@ struct MainFlags { string target_features; string entry_point; string cpp_class; - string out_object; + string out_function_object; + string out_metadata_object; string out_header; string out_session_module; diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index ac79c278c1fdf8b6aedcb52121c767b8ba0ad358..6d603a02eb4ceade6832ba67b2981814ee25327a 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index a898eab1d1ab0eb5d55983bf366753c968887296..89c7cd4507cbd476104a039d6083d8f89de11278 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import argparse +import os import sys from tensorflow.core.protobuf import saver_pb2 @@ -53,7 +54,7 @@ def tfadd_with_ckpt(out_dir): sess.run(init_op) sess.run(y.assign(y + 42)) # Without the checkpoint, the variable won't be set to 42. - ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % out_dir + ckpt = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt.ckpt') saver.save(sess, ckpt) @@ -68,10 +69,10 @@ def tfadd_with_ckpt_saver(out_dir): sess.run(init_op) sess.run(y.assign(y + 42)) # Without the checkpoint, the variable won't be set to 42. - ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % out_dir + ckpt_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.ckpt') saver.save(sess, ckpt_file) # Without the SaverDef, the restore op won't be named correctly. - saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir + saver_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.saver') with open(saver_file, 'wb') as f: f.write(saver.as_saver_def().SerializeToString()) @@ -129,7 +130,7 @@ def write_graph(build_graph, out_dir): g = ops.Graph() with g.as_default(): build_graph(out_dir) - filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__) + filename = os.path.join(out_dir, 'test_graph_%s.pb' % build_graph.__name__) with open(filename, 'wb') as f: f.write(g.as_graph_def().SerializeToString()) diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 6b037f276ad1d6771b904bb970f45f32ae9531b8..413efd9cea3b6f71574615ad9ca92471ff925781 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -70,7 +70,7 @@ TEST(TFCompileTest, Add) { // Run tests that use set_argN_data separately, to avoid accidentally re-using // non-existent buffers. TEST(TFCompileTest, Add_SetArg) { - AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); int32 arg_x = 10; int32 arg_y = 32; @@ -258,7 +258,7 @@ TEST(TFCompileTest, MatMul2_SetArg) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); foo::bar::MatMulComp matmul( - foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); matmul.set_thread_pool(&device); // Test using the set_argN_data() methods. diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 6c385af3b36df78b3f674b3464d68d904ca92907..2b9c83ba149adf9e089786b91039e256216579c8 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -128,7 +128,8 @@ def tf_library(name, graph, config, # Rule that runs tfcompile to produce the header and object file. header_file = name + ".h" - object_file = name + ".o" + metadata_object_file = name + "_tfcompile_metadata.o" + function_object_file = name + "_tfcompile_function.o" ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") if type(tfcompile_flags) == type(""): flags = tfcompile_flags @@ -142,7 +143,8 @@ def tf_library(name, graph, config, ], outs=[ header_file, - object_file, + metadata_object_file, + function_object_file, ], cmd=("$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + @@ -151,7 +153,8 @@ def tf_library(name, graph, config, " --cpp_class=" + cpp_class + " --target_triple=" + target_llvm_triple() + " --out_header=$(@D)/" + header_file + - " --out_object=$(@D)/" + object_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + " " + flags), tools=[tfcompile_tool], visibility=visibility, @@ -202,7 +205,7 @@ def tf_library(name, graph, config, need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) native.cc_library( name=name, - srcs=[object_file], + srcs=[function_object_file, metadata_object_file], hdrs=[header_file], visibility=visibility, testonly=testonly, @@ -267,7 +270,6 @@ def tf_library(name, graph, config, srcs=[test_file], deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", "@org_tensorflow//tensorflow/compiler/aot:runtime", "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main", "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", @@ -313,7 +315,6 @@ def tf_library(name, graph, config, linkopts = if_android(["-pie", "-s"]), deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", "@org_tensorflow//tensorflow/compiler/aot:benchmark", "@org_tensorflow//tensorflow/compiler/aot:runtime", "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 6ab3d474187c7df2131f94c9f42f0d0f2f9d99d7..e2f01179d4e2e4f6ef72b2761d06e130ffa3a94f 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -91,19 +91,26 @@ Status Main(const MainFlags& flags) { // Write output files. Env* env = Env::Default(); const std::vector& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object, + TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, StringPiece(obj.data(), obj.size()))); - HeaderOpts header_opts; - header_opts.gen_name_to_index = flags.gen_name_to_index; - header_opts.gen_program_shape = flags.gen_program_shape; + CodegenOpts codegen_opts; + codegen_opts.gen_name_to_index = flags.gen_name_to_index; + codegen_opts.gen_program_shape = flags.gen_program_shape; + codegen_opts.target_triple = flags.target_triple; if (flags.cpp_class.empty()) { return errors::InvalidArgument("Must specify --cpp_class"); } - TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name, - &header_opts.namespaces)); - string header; + TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, + &codegen_opts.namespaces)); + + MetadataResult metadata_result; TF_RETURN_IF_ERROR( - GenerateHeader(header_opts, config, compile_result, &header)); + GenerateMetadata(codegen_opts, compile_result, &metadata_result)); + TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object, + metadata_result.object_file_data)); + string header; + TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result, + metadata_result, &header)); TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header)); return Status::OK(); } @@ -114,7 +121,8 @@ Status Main(const MainFlags& flags) { int main(int argc, char** argv) { tensorflow::tfcompile::MainFlags flags; flags.target_triple = "x86_64-pc-linux"; - flags.out_object = "out.o"; + flags.out_function_object = "out_model.o"; + flags.out_metadata_object = "out_helper.o"; flags.out_header = "out.h"; flags.entry_point = "entry"; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index bf7d9cf14d10f41aa48ea594a8d63db97b9973e1..a711319607f4ff2b83aa0ebe50e215b3d0e2258e 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -110,19 +110,6 @@ cc_library( alwayslink = True, ) -# Internal targets below this point. - -cc_library( - name = "common", - srcs = [ - "defs.cc", - ], - hdrs = [ - "defs.h", - ], - visibility = [":friends"], -) - cc_library( name = "xla_device", srcs = [ @@ -135,6 +122,8 @@ cc_library( "xla_device_context.h", "xla_device_ops.h", ], + # Public visibility is needed for external TF/XLA backends. + visibility = ["//visibility:public"], deps = [ ":common", ":jit_compilation_passes", @@ -164,6 +153,19 @@ cc_library( ], ) +# Internal targets below this point. + +cc_library( + name = "common", + srcs = [ + "defs.cc", + ], + hdrs = [ + "defs.h", + ], + visibility = [":friends"], +) + cc_library( name = "xla_compilation_cache", srcs = ["xla_compilation_cache.cc"], @@ -215,7 +217,6 @@ cc_library( ":common", ":compilation_passes", "//tensorflow/compiler/jit/kernels:xla_launch_op", - "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -245,12 +246,13 @@ cc_library( "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", - "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 22899ebeebc929055518893b358f7950d380d6f6..0de163d3a8f082eab4d8d802485da1bbc56e8180 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -16,13 +16,18 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include +#include #include +#include +#include +#include #include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" @@ -32,6 +37,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -48,19 +54,75 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; namespace { +bool AreAllParentsConst(const Node& n, + const gtl::FlatSet& runtime_const_nodes) { + if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") { + // If the current node is itself a cast-to-const, no need + // to look at the incoming edges. + return true; + } + + bool all_parents_const = true; + bool atleast_one_non_control_edge = false; + for (const Edge* in : n.in_edges()) { + atleast_one_non_control_edge = + atleast_one_non_control_edge || !in->IsControlEdge(); + if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) { + all_parents_const = false; + break; + } + } + return all_parents_const && atleast_one_non_control_edge; +} + +void MarkGuaranteedConstants( + const Graph& graph, + const std::vector>& src_arg_pairs) { + gtl::FlatSet guaranteed_const_nodes; + std::vector srcs; + srcs.reserve(src_arg_pairs.size()); + for (const auto& src_arg : src_arg_pairs) { + srcs.push_back(src_arg.first); + } + ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, + /*leave=*/[&guaranteed_const_nodes](const Node* n) { + // TODO(vinuraja): Doesn't work in the presence of loops. + if (AreAllParentsConst(*n, guaranteed_const_nodes)) { + guaranteed_const_nodes.insert(n); + } + }); + + for (auto& src_arg : src_arg_pairs) { + if (guaranteed_const_nodes.count(src_arg.first) != 0) { + VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString(); + src_arg.second->AddAttr("_is_guaranteed_constant", true); + } + } +} + // A node/slot pair. // TODO(phawkins): is there a common definition of this? struct NodeSlot { - NodeSlot() : node(nullptr), slot(-1) {} - NodeSlot(const Node* node, int slot) : node(node), slot(slot) {} + NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {} + NodeSlot(const Node* node, int slot) + : node(node), slot(slot), dtype(DT_INVALID) {} + NodeSlot(const Node* node, int slot, DataType dtype) + : node(node), slot(slot), dtype(dtype) {} const Node* node; int slot; + // Optional: used to record the destination type of a source NodeSlot in case + // the source output is a Ref type that is cast to a Tensor at the + // destination. + DataType dtype; + bool operator==(const NodeSlot& other) const { - return node == other.node && slot == other.slot; + return node == other.node && slot == other.slot && dtype == other.dtype; } + // Leave dtype out of the hash since there are never two NodeSlots with the + // same node and slot and different dtypes. struct Hasher { uint64 operator()(NodeSlot const& s) const { return Hash64Combine(std::hash()(s.node), @@ -75,10 +137,23 @@ struct NodeSlot { }; }; +// TODO(phawkins) add a canonical copy of these operator names and refactor +// everything to use it. +static const char* const kArgOp = "_Arg"; +static const char* const kRetValOp = "_Retval"; +static const char* const kSendToHostOp = "_XlaSendToHost"; +static const char* const kRecvFromHostOp = "_XlaRecvFromHost"; +static const char* const kSendFromHostOp = "_XlaSendFromHost"; +static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; + class Encapsulator { public: - Encapsulator(string group_attribute, Graph const* graph_in) - : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {} + Encapsulator(string group_attribute, string outside_compilation_attribute, + Graph const* graph_in) + : group_attribute_(std::move(group_attribute)), + outside_compilation_attribute_( + std::move(outside_compilation_attribute)), + graph_in_(graph_in) {} // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. @@ -99,54 +174,350 @@ class Encapsulator { Status BuildOutputGraph(bool parallel_checking, Graph* graph_out); private: - // Returns the key attribute associated with a node. Returns the empty string - // if no key attribute is found. - string GetFunctionNameAttr(const Node* node) const; - // A subgraph of the input, all marked with a common 'group_attribute' - // value. - struct Subgraph { + // value. A subgraph may contain multiple `outside_compilation' clusters. + // + // In the following simple example, A, B, ..., E are nodes in the original + // graph. The group attributes and outside_compilation attributes g and oc are + // each shown as either 0 or empty. + // + // A --> B --> C --> D --> E + // g: g:0 g:0 g:0 g: + // oc: oc: oc:0 oc: oc: + // + // The example is rewritten to two graphs; one on the host and one to be + // compiled. The host graph is as follows. RAH is a RecvAtHost node receiving + // input from the compiled cluster, and SFH is a SendFromHost node sending + // input back to the compiled cluster. Dotted edges are control edges. A + // 'sequencing' node S is inserted, and both RAH and SFH are connected via S + // to E (and in general all nodes that depend on nodes in the compiled + // cluster) to ensure that they are not pruned. + // + // A --> Call --> E + // ^ + // . + // ........> S + // .... ^ + // .. . + // RAH --> C --> SFH + // + // The compiled cluster is as follows. STH is a SendToHost node which is the + // source of a channel to the RAH node above. RFH is a RecvFromHost node which + // is the destination of a channel from the SFH node above. There is a control + // edge that ensures RFH follows STH, which is used in shape inference to + // ensure that the shapes on the STH host channel are known before the RFH + // channel is compiled. + // + // Arg --> B --> STH ..> RFH --> D --> Retval + // + // The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most + // one RAH and SFH in each compiled cluster. This design is preferred over + // adding separate Arg/Retval nodes for each transmitted value because it + // simplifies the host code that would like to limit communication between + // host and device and, e.g., raise only one interrupt per channel rather than + // one per transmitted value. + class Subgraph { + public: + // Creates a graph to build the subgraph in, if it doesn't already exist, + // using the same op registry and versions as graph_in. + Node* MakeNodeImage(const Graph* graph_in, Node* node); + + // Returns the graph the subgraph is being built in. + Graph* GetGraph() const; + + // Builds a FunctionDef, and adds it to 'library'. The value of the + // 'group_attribute' annotations becomes the function name. If + // 'reuse_existing_functions' is set, use an existing function with the same + // name, if any. If 'rewrite_subgraph_fn' is set, it is applied to the + // subgraph before function conversion. + Status BuildFunctionDef(const string& name_in, + const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, + FunctionLibraryDefinition* library); + + // Adds the function call node to graph_out. + Status AddFunctionCallNode( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out); + + // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. + Status AddOutsideCompilationHostIONodes( + const string& subgraph_name, + const std::unordered_map& node_images, + Graph* graph_out); + + // Returns the Node that inputs to the function should be wired up to. + Node* GetCallNodeForInputs() const; + + // Returns the Node that outputs to the function should be wired up to. + Node* GetCallNodeForOutputs() const; + + // Returns the index of the arg that the dst of edge should connect to. + int GetArgIndexForEdge(const Edge* edge) const; + + // Returns the index of the result that the src of edge should connect to. + int GetResultIndexForEdge(const Edge* edge) const; + + // Returns the RecvAtHost node for an outside_compilation subgraph. + Node* GetRecvAtHostNode( + const string& outside_compilation_subgraph_name) const; + + // Returns the output slot for the RecvAtHost node that corresponds to the + // source of edge in an outside_compilation subgraph. + int GetRecvAtHostSlot(const string& outside_compilation_subgraph_name, + const Edge* edge) const; + + // Returns the SendFromHost node for an outside_compilation subgraph. + Node* GetSendFromHostNode( + const string& outside_compilation_subgraph_name) const; + + // Returns the input slot for the SendFromHost node that corresponds to the + // destination of edge in an outside_compilation subgraph. + int GetSendFromHostSlot(const string& outside_compilation_subgraph_name, + const Edge* edge) const; + + // Creates an _Arg node for the src node of edge, and add its index to + // args_by_src_, if none exists yet. Also adds its index to args_by_dst_, + // and adds the edge within the subgraph from the _Arg node to the image of + // the dst node. + Status RecordArg(const Edge* edge, + const std::unordered_map& node_images, + std::vector>* src_arg_pairs); + + // Creates a _Retval node for the src node of edge, and add it to results_, + // if none exists yet. If a new _Retval node is created, also adds the edge + // within the subgraph from the src to the _Retval node. + Status RecordResult( + const Edge* edge, + const std::unordered_map& node_images); + + // Creates an outside_compilation subgraph for outside_compilation_id if + // none exists yet. Creates an entry for the src node of edge in the list of + // inputs for the outside_compilation subgraph, if none exists yet. + void RecordOutsideCompilationInputOrControl( + const string& outside_compilation_id, const Edge* edge); + + // Creates an outside_compilation subgraph for outside_compilation_id if + // none exists yet. Creates an entry for the src node of edge in the list of + // outputs by src for the outside_compilation subgraph, if none exists + // yet. Creates an entry for the dst node of edge in the list of outputs by + // dst for the outside_compilation subgraph. + void RecordOutsideCompilationOutputOrControl( + const string& outside_compilation_id, const Edge* edge); + + // Adds the SendToHost nodes for each outside_compilation subgraph once the + // edges have all been recorded via RecordOutsideCompilationInputOrControl. + Status AddSendsToOutsideCompilation( + const std::unordered_map& node_images); + + // Adds the RecvFromHost nodes for each outside_compilation subgraph once + // the edges have all been recorded via + // RecordOutsideCompilationOutputOrControl. + Status AddRecvsFromOutsideCompilation( + const std::unordered_map& node_images); + + // Creates the sequencer node if it doesn't exist, adding it to graph_out. + Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); + + // If there is a sequencer node, adds a control edge from the sequencer to + // all the downstream nodes of call_node_outputs. + void ConnectSequencerToOutputs(Graph* graph_out); + + private: + struct OutsideCompilationSubgraph { + // Map from source (producer node/slot) tensors in the original graph to + // input index (slot number in the SendToHost/RecvAtHost nodes that will + // be created) for the outside_compilation subgraph. + std::unordered_map inputs; + + // Set of nodes in the original graph that are the source of control edges + // that cross from the containing compiled subgraph into the + // outside_compilation subgraph. These are recorded by + // RecordOutsideCompilationInputOrControl while walking all the subgraph + // edges, and lifted control edges within the subgraph are added by + // AddSendsToOutsideCompilation once the _SendToHost node has been + // created. The matching control edge from _RecvAtHost to the + // destination is added by CopyEdgeToOutputGraph. + std::unordered_set control_inputs; + + // Maps from source (producer node/slot) and destination (consumer + // node/slot) tensors in the original graph to output index (slot number + // in the SendFromHost/RecvFromHost nodes that will be created) for the + // outside_compilation subgraph. + std::unordered_map outputs_by_src; + std::unordered_map outputs_by_dst; + + // Set of nodes in the original graph that are the destination of control + // edges that cross from the outside_compilation subgraph into the + // containing compiled subgraph. These are recorded by + // RecordOutsideCompilationOutputOrControl while walking all the subgraph + // edges, and lifted control edges within the subgraph are added by + // AddRecvsFromToOutsideCompilation once the _RecvFromHost node has been + // created. The matching control edge from the source to _SendFromHost to + // the destination is added by CopyEdgeToOutputGraph. + std::unordered_set control_outputs; + + // _SendToHost node in the subgraph. Not owned. + Node* send_to_host = nullptr; + + // _RecvAtHost node in the output graph. Not owned. + Node* recv_at_host = nullptr; + + // _SendFromHost node in the output graph. Not owned. + Node* send_from_host = nullptr; + }; + + // Builds a ParallelCheck op that compares the output of the original + // subgraph with the encapsulated subgraph. + Status BuildParallelCheckOp( + const std::unordered_map& node_images, + Graph* graph_out); + + // Builds a _RecvAtHost node producing all the inputs of an + // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. + Status AddRecvAtHostNode(const string& subgraph_name, + const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, + Graph* graph_out); + + // Builds a _SendFromHost node consuming all the outputs of an + // outside_compilation subgraph and stores it in oc_subgraph.send_from_host. + Status AddSendFromHostNode( + const std::unordered_map& node_images, + const string& subgraph_name, const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); + // The subgraph extracted from the input graph, suitable for being turned // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are // returned by _Retval nodes. - std::unique_ptr graph; + std::unique_ptr graph_; // Which device are these nodes on? Used to assign a device to the call // node. - string device; + string device_; // NodeDef for the function call node. - NodeDef call_node_def; + NodeDef call_node_def_; // Function call node(s) in the output graph. Not owned. // If parallel_checking is enabled, 'call_node_inputs' is the function call // node to which inputs should be fed, and 'call_node_outputs' is the // parallel check op from which outputs should be read. If parallel checking // is disabled, both point to the function call node. - Node* call_node_inputs; - Node* call_node_outputs; + Node* call_node_inputs_; + Node* call_node_outputs_; // Maps from source (producer node/slot) and destination // (consumer node/slot) tensors in the input graph to _Arg numbers in // the subgraph. The source map is one-to-one, whereas the dest map may be // many-to-one. - std::unordered_map args_by_src; - std::unordered_map args_by_dst; + std::unordered_map args_by_src_; + std::unordered_map args_by_dst_; // The _Arg nodes in the subgraph, in order by argument number. - std::vector args; + std::vector args_; // Map from source tensor in the input graph to result #. - std::unordered_map results; + std::unordered_map results_; + + // The outside_compilation clusters in this subgraph. + std::unordered_map + outside_compilation_subgraphs_; + + // NoOp node in the output graph that is sequenced after the call node and + // used to prevent host-side outside_compilation sends and recvs from being + // pruned. + Node* sequencer_ = nullptr; }; - // Builds a ParallelCheck op that compares the output of the original subgraph - // with the encapsulated subgraph. - Status BuildParallelCheckOp( + // Returns the key attribute and outside_compilation attribute associated + // with a node in attr, and outside_compilation_attr, respectively. Sets + // either result to the empty string if the respective attribute is not + // found. Returns error status if there is an outside_compilation attribute + // and no key attribute, + Status GetFunctionNameAttr(Node const* node, string* attr, + string* outside_compilation_attr) const; + + // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to + // subgraphs for data edges that cross subgraph boundaries. + Status CopySubgraphEdges( const std::unordered_map& node_images, - const Subgraph& subgraph, Graph* graph_out, Node** parallel_check_op); + std::vector>* src_arg_pairs); + + // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes, + // or nodes marked outside_compilation. + Status CopySubgraphNodes(std::unordered_map* node_images); + + // Copies all nodes that aren't in a compiled subgraph to the output graph. + Status CopyNodesToOutputGraph( + bool parallel_checking, Graph* graph_out, + std::unordered_map* node_images); + + // Adds function call nodes for each compiled subgraph. + Status AddFunctionCallNodes( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out); + + // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all + // outside_compilation subgraphs. + Status AddOutsideCompilationHostIONodes( + const std::unordered_map& node_images, + Graph* graph_out); + + // Finds the image of an edge source in the output graph. If the edge crosses + // a subgraph boundary it is the output of a call node, otherwise it is a node + // in the output graph. + Status FindOutputImageOfEdgeSrc( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + const Node* original_src_node, Node** src_image); + + // Finds an edge source slot in the output graph. If the edge crosses a + // subgraph boundary it is a slot on the output of a call node or a + // _RecvAtHost node, otherwise it is a slot on a node in the output graph. + int FindOutputSlotOfEdgeSrc(const string& src_func_id, + const string& src_outside_compilation_id, + const string& dst_func_id, + const string& dst_outside_compilation_id, + const Edge* edge); + + // Finds the image of an edge destination in the output graph. If the edge + // crosses a subgraph boundary it is the input of a call node or a + // _SendFromHost node, otherwise it is a node in the output graph. + Status FindOutputImageOfEdgeDst( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + const Node* original_dst_node, Node** dst_image); + + // Finds an edge destination slot in the output graph. If the edge crosses a + // subgraph boundary it is a slot on the input of a call node or a + // _SendFromHost node, otherwise it is a slot on a node in the output graph. + int FindOutputSlotOfEdgeDst(const string& src_func_id, + const string& src_outside_compilation_id, + const string& dst_func_id, + const string& dst_outside_compilation_id, + const Edge* edge); + + // Copies a single edge to the output graph. The edge is either entirely + // within the output graph, or crosses into or out of a compiled subgraph. + Status CopyEdgeToOutputGraph( + const Edge* edge, const string& src_func_id, + const string& src_outside_compilation_id, const string& dst_func_id, + const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out, + std::unordered_set, NodeSlot::PairHasher>* + edges_added); + + // Adds all edges to the output graph. + Status AddEdgesToOutputGraph( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out); const string group_attribute_; + const string outside_compilation_attribute_; const Graph* graph_in_; std::unordered_map subgraphs_; @@ -154,224 +525,370 @@ class Encapsulator { TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator); }; -// TODO(phawkins) add a canonical copy of these operator names and refactor -// everything to use it. -static const char* const kArgOp = "_Arg"; -static const char* const kRetValOp = "_Retval"; +Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { + return call_node_inputs_; +} -// Returns the function name attached to 'node', or the empty string if there is -// none. -string Encapsulator::GetFunctionNameAttr(Node const* node) const { - string attr; - if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) { - attr.clear(); - } - return attr; +Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const { + return call_node_outputs_; } -Status Encapsulator::SplitIntoSubgraphs() { - Status s; +int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { + return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); +} - // Map from input graph nodes to subgraph nodes. - std::unordered_map node_images; +int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { + return results_.at(NodeSlot(edge->src(), edge->src_output())); +} - // Copy all marked nodes to a subgraph. Do nothing for unmarked nodes. - for (Node* node : graph_in_->op_nodes()) { - string func_id = GetFunctionNameAttr(node); - if (func_id.empty()) continue; +Node* Encapsulator::Subgraph::GetRecvAtHostNode( + const string& outside_compilation_subgraph_name) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .recv_at_host; +} - Subgraph& subgraph = subgraphs_[func_id]; - if (!subgraph.graph) { - subgraph.graph.reset(new Graph(graph_in_->op_registry())); - subgraph.graph->set_versions(graph_in_->versions()); - } +int Encapsulator::Subgraph::GetRecvAtHostSlot( + const string& outside_compilation_subgraph_name, const Edge* edge) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .inputs.at(NodeSlot(edge->src(), edge->src_output())); +} - Node* image = subgraph.graph->CopyNode(node); - image->ClearAttr(group_attribute_); - node_images[node] = image; +Node* Encapsulator::Subgraph::GetSendFromHostNode( + const string& outside_compilation_subgraph_name) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .send_from_host; +} - if (subgraph.device.empty()) { - subgraph.device = node->assigned_device_name().empty() - ? node->requested_device() - : node->assigned_device_name(); - } +int Encapsulator::Subgraph::GetSendFromHostSlot( + const string& outside_compilation_subgraph_name, const Edge* edge) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); +} + +Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { + if (!graph_) { + graph_.reset(new Graph(graph_in->op_registry())); + graph_->set_versions(graph_in->versions()); } - // Copy edges local to a subgraph. Add _Arg and _Retval nodes to subgraphs for - // data edges that cross subgraph boundaries. - for (const Edge* edge : graph_in_->edges()) { - string src_func_id = GetFunctionNameAttr(edge->src()); - string dst_func_id = GetFunctionNameAttr(edge->dst()); - Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); - Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); + if (device_.empty()) { + device_ = node->assigned_device_name().empty() + ? node->requested_device() + : node->assigned_device_name(); + } - // Copy edges that are local to a subgraph. - if (!src_func_id.empty() && src_func_id == dst_func_id) { - Graph* g = subgraphs_[src_func_id].graph.get(); - if (edge->IsControlEdge()) { - g->AddControlEdge(src_image, dst_image); - } else { - g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); - } - continue; - } + return graph_->CopyNode(node); +} - // Ignore cross-boundary control edges for right now. We will lift them - // onto the enclosing call operators in BuildOutputGraph(). - if (edge->IsControlEdge()) continue; +Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); } + +Status Encapsulator::Subgraph::RecordArg( + const Edge* edge, const std::unordered_map& node_images, + std::vector>* src_arg_pairs) { + Node* src_node = edge->src(); + int src_slot = edge->src_output(); + std::unordered_map::iterator iter; + bool inserted; + std::tie(iter, inserted) = + args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size()); + int arg_index = iter->second; + if (inserted) { + NodeDef arg_def; + NodeDefBuilder builder( + strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + DataType dtype = edge->dst()->input_type(edge->dst_input()); + builder.Attr("T", dtype); + builder.Attr("index", arg_index); + Status s = builder.Finalize(&arg_def); + if (!s.ok()) return s; - // Add 'src' as an output of its subgraph, if applicable. - if (!src_func_id.empty()) { - Subgraph& src_subgraph = subgraphs_[src_func_id]; - int ret_index = src_subgraph.results.size(); - if (src_subgraph.results - .emplace(NodeSlot(edge->src(), edge->src_output()), ret_index) - .second) { - // Create a new _Retval node - DataType dtype = edge->src()->output_type(edge->src_output()); + Node* arg = graph_->AddNode(arg_def, &s); + if (!s.ok()) return s; - if (IsRefType(dtype)) { - return errors::InvalidArgument( - "Ref Tensors (e.g., Variables) are not supported: tensor ", - edge->src()->name(), ":", edge->src_output()); - } + src_arg_pairs->push_back({src_node, arg}); + args_.push_back(arg); + } + Node* dst_node = edge->dst(); + Node* dst_image = node_images.at(dst_node); + int dst_slot = edge->dst_input(); + args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index; + graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot); + return Status::OK(); +} - NodeDef ret_def; - ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat(edge->src()->name(), "_", - edge->src_output(), "_retval")); - AddNodeAttr("T", dtype, &ret_def); - AddNodeAttr("index", ret_index, &ret_def); - Node* ret = src_subgraph.graph->AddNode(ret_def, &s); - if (!s.ok()) return s; - - // Add an edge from 'src' to _Retval. - src_subgraph.graph->AddEdge(src_image, edge->src_output(), ret, 0); - } - } +Status Encapsulator::Subgraph::RecordResult( + const Edge* edge, + const std::unordered_map& node_images) { + Node* src_node = edge->src(); + Node* src_image = node_images.at(src_node); + int src_slot = edge->src_output(); + std::unordered_map::iterator iter; + bool inserted; + std::tie(iter, inserted) = + results_.emplace(NodeSlot(src_node, src_slot), results_.size()); + int ret_index = iter->second; + if (inserted) { + NodeDef ret_def; + NodeDefBuilder builder( + strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + DataType dtype = src_node->output_type(src_slot); + builder.Attr("T", dtype); + builder.Attr("index", ret_index); + builder.Input(src_image->name(), src_slot, dtype); + Status s = builder.Finalize(&ret_def); + if (!s.ok()) return s; + Node* ret = graph_->AddNode(ret_def, &s); + if (!s.ok()) return s; - // Add 'dst' as an input of its subgraph, if applicable. - if (!dst_func_id.empty()) { - Subgraph& dst_subgraph = subgraphs_[dst_func_id]; + graph_->AddEdge(src_image, src_slot, ret, 0); + } + return Status::OK(); +} - // Create an _Arg node for this tensor, if none exists yet. - std::unordered_map::iterator iter; - bool inserted; - std::tie(iter, inserted) = dst_subgraph.args_by_src.emplace( - NodeSlot(edge->src(), edge->src_output()), dst_subgraph.args.size()); - int arg_index = iter->second; - if (inserted) { - // This is the first time we have seen this tensor. Create an _Arg node. - DataType dtype = edge->dst()->input_type(edge->dst_input()); +void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( + const string& outside_compilation_id, const Edge* edge) { + auto iter = outside_compilation_subgraphs_ + .emplace(outside_compilation_id, OutsideCompilationSubgraph()) + .first; + OutsideCompilationSubgraph& outside_subgraph = iter->second; + if (edge->IsControlEdge()) { + outside_subgraph.control_inputs.insert(edge->src()); + } else { + int input_index = outside_subgraph.inputs.size(); + outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()), + input_index); + } +} - if (IsRefType(dtype)) { - return errors::InvalidArgument( - "Ref Tensors (e.g., Variables) are not supported: tensor ", - edge->src()->name(), ":", edge->src_output()); - } +void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( + const string& outside_compilation_id, const Edge* edge) { + auto subgraph_iter = + outside_compilation_subgraphs_ + .emplace(outside_compilation_id, OutsideCompilationSubgraph()) + .first; + OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second; + if (edge->IsControlEdge()) { + outside_subgraph.control_outputs.insert(edge->dst()); + } else { + DataType dtype = edge->dst()->input_type(edge->dst_input()); + auto output_iter = + outside_subgraph.outputs_by_src + .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), + outside_subgraph.outputs_by_src.size()) + .first; + int output_index = output_iter->second; + outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + output_index; + } +} - NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat(edge->src()->name(), "_", - edge->src_output(), "_arg"), - kArgOp); - builder.Attr("T", dtype); - builder.Attr("index", arg_index); - s = builder.Finalize(&arg_def); - if (!s.ok()) return s; +Status Encapsulator::Subgraph::AddSendsToOutsideCompilation( + const std::unordered_map& node_images) { + for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { + const string& oc_subgraph_name = oc_subgraph_iter.first; + OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; + if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { + // Build a _SendToHost node sending all the args of the appropriate + // types. + std::vector dtypes(oc_subgraph.inputs.size(), DT_INVALID); + std::vector inputs(oc_subgraph.inputs.size()); + + for (const auto& input_src : oc_subgraph.inputs) { + const Node* src_node = input_src.first.node; + Node* src_image = node_images.at(src_node); + int src_slot = input_src.first.slot; + int input_index = input_src.second; + + DataType dtype = src_node->output_type(src_slot); + dtypes[input_index] = dtype; + inputs[input_index].Reset(src_image->name(), src_slot, dtype); + } - Node* arg = dst_subgraph.graph->AddNode(arg_def, &s); - if (!s.ok()) return s; + NodeDef send_def; + NodeDefBuilder builder( + strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"), + kSendToHostOp); + builder.Attr("dtypes", dtypes); + builder.Input(inputs); + Status s = builder.Finalize(&send_def); + if (!s.ok()) return s; + + oc_subgraph.send_to_host = graph_->AddNode(send_def, &s); + if (!s.ok()) return s; + + // Connect the _SendToHost node to its producers in the subgraph. + for (auto& input_src : oc_subgraph.inputs) { + const Node* src_node = input_src.first.node; + Node* src_image = node_images.at(src_node); + int src_slot = input_src.first.slot; + int input_index = input_src.second; + graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host, + input_index); + } - dst_subgraph.args.push_back(arg); + // Connect the _SendToHost node to its control edge producers in the + // subgraph. + for (const auto& src_node : oc_subgraph.control_inputs) { + Node* src_image = node_images.at(src_node); + graph_->AddControlEdge(src_image, oc_subgraph.send_to_host); } - // Add an edge from the _Arg node to 'dst' in the subgraph. - dst_subgraph.args_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = - arg_index; - dst_subgraph.graph->AddEdge(dst_subgraph.args[arg_index], 0, dst_image, - edge->dst_input()); } } - for (auto& entry : subgraphs_) { - FixupSourceAndSinkEdges(entry.second.graph.get()); - } - - return s; + return Status::OK(); } -Status Encapsulator::BuildFunctionDefs( - const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, - FunctionLibraryDefinition* library) { - // For each subgraph, build a FunctionDef. - for (auto& subgraph_entry : subgraphs_) { - string name = subgraph_entry.first; - Subgraph& subgraph = subgraph_entry.second; - - subgraph.call_node_def.set_op(name); - subgraph.call_node_def.set_name(name); - subgraph.call_node_def.set_device(subgraph.device); - - if (rewrite_subgraph_fn) { - // Initialize the input and output permutations to the identity. - std::vector input_permutation(subgraph.args_by_src.size()); - std::iota(input_permutation.begin(), input_permutation.end(), 0); - std::vector output_permutation(subgraph.results.size()); - std::iota(output_permutation.begin(), output_permutation.end(), 0); - - TF_RETURN_IF_ERROR( - rewrite_subgraph_fn(&subgraph.graph, &input_permutation, - &output_permutation, &subgraph.call_node_def)); - - // Apply the input/output permutations to the 'args_by_...' and 'results' - // mappings in 'subgraph', so when we build edges in BuildOutputGraph() we - // connect them to the right input/output positions. - if (input_permutation.size() != subgraph.args_by_src.size()) { - return errors::InvalidArgument("Input permutation has incorrect size."); +Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation( + const std::unordered_map& node_images) { + for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { + const string& oc_subgraph_name = oc_subgraph_iter.first; + OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; + if (!oc_subgraph.outputs_by_src.empty() || + !oc_subgraph.control_outputs.empty()) { + // Build a _RecvFromHost node producing all the outputs of the appropriate + // types. + std::vector dtypes(oc_subgraph.outputs_by_src.size(), + DT_INVALID); + + for (const auto& output : oc_subgraph.outputs_by_src) { + DataType dtype = output.first.dtype; + int output_index = output.second; + dtypes[output_index] = dtype; } - if (output_permutation.size() != subgraph.results.size()) { - return errors::InvalidArgument( - "Output permutation has incorrect size."); - } - for (auto& arg : subgraph.args_by_src) { - arg.second = input_permutation[arg.second]; - } - for (auto& arg : subgraph.args_by_dst) { - arg.second = input_permutation[arg.second]; + + NodeDef recv_def; + NodeDefBuilder builder( + strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"), + kRecvFromHostOp); + builder.Attr("dtypes", dtypes); + Status s = builder.Finalize(&recv_def); + if (!s.ok()) return s; + + Node* recv = graph_->AddNode(recv_def, &s); + if (!s.ok()) return s; + + // Connect the consumers in the subgraph to the _RecvFromHost node. + for (const auto& output : oc_subgraph.outputs_by_dst) { + const Node* dst_node = output.first.node; + Node* dst_image = node_images.at(dst_node); + int dst_slot = output.first.slot; + int output_index = output.second; + + graph_->AddEdge(recv, output_index, dst_image, dst_slot); } - for (auto& result : subgraph.results) { - result.second = output_permutation[result.second]; + + // Connect the control edge consumers in the subgraph to the _RecvFromHost + // node. + for (const auto& dst_node : oc_subgraph.control_outputs) { + Node* dst_image = node_images.at(dst_node); + graph_->AddControlEdge(recv, dst_image); } - name = subgraph.call_node_def.op(); + // Add a control edge in the subgraph so that the _SendToHost node, if + // any, is compiled before the _RecvFromHost node. + if (oc_subgraph.send_to_host != nullptr) { + graph_->AddControlEdge(oc_subgraph.send_to_host, recv); + } } + } + + return Status::OK(); +} + +Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, + Graph* graph_out) { + if (sequencer_ == nullptr) { + NodeDef seq_def; + NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), + "NoOp"); + Status s = builder.Finalize(&seq_def); + if (!s.ok()) return s; - FunctionDef fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph.graph, name, &fdef)); + sequencer_ = graph_out->AddNode(seq_def, &s); + if (!s.ok()) return s; + sequencer_->set_assigned_device_name(device_); + } + return Status::OK(); +} - if (VLOG_IS_ON(1)) { - VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_fdef_graph_", name), *subgraph.graph, - library); - dump_graph::DumpFunctionDefToFile( - strings::StrCat("encapsulate_fdef_", name), fdef); +void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) { + if (sequencer_ != nullptr) { + std::unordered_set output_dependencies; + for (Node* node : call_node_outputs_->out_nodes()) { + output_dependencies.insert(node); + } + for (Node* node : output_dependencies) { + graph_out->AddControlEdge(sequencer_, node); } + } +} - if (!reuse_existing_functions || library->Find(name) == nullptr) { - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); +Status Encapsulator::Subgraph::BuildFunctionDef( + const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, FunctionLibraryDefinition* library) { + // name_in is copied here because name may be modified below if + // rewrite_subgraph_fn is true. + string name = name_in; + call_node_def_.set_op(name); + call_node_def_.set_name(name); + call_node_def_.set_device(device_); + + if (rewrite_subgraph_fn) { + // Initialize the input and output permutations to the identity. + std::vector input_permutation(args_by_src_.size()); + std::iota(input_permutation.begin(), input_permutation.end(), 0); + std::vector output_permutation(results_.size()); + std::iota(output_permutation.begin(), output_permutation.end(), 0); + + TF_RETURN_IF_ERROR(rewrite_subgraph_fn( + &graph_, &input_permutation, &output_permutation, &call_node_def_)); + + // Apply the input/output permutations to the 'args_by_...' and 'results_' + // mappings, so when we build edges in BuildOutputGraph() we + // connect them to the right input/output positions. + if (input_permutation.size() != args_by_src_.size()) { + return errors::InvalidArgument("Input permutation has incorrect size."); + } + if (output_permutation.size() != results_.size()) { + return errors::InvalidArgument("Output permutation has incorrect size."); + } + for (auto& arg : args_by_src_) { + arg.second = input_permutation[arg.second]; + } + for (auto& arg : args_by_dst_) { + arg.second = input_permutation[arg.second]; } + for (auto& result : results_) { + result.second = output_permutation[result.second]; + } + + name = call_node_def_.op(); + } + + FunctionDef fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); + + if (VLOG_IS_ON(1)) { + VLOG(2) << "Build function def " << name; + dump_graph::DumpGraphToFile( + strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library); + dump_graph::DumpFunctionDefToFile( + strings::StrCat("encapsulate_fdef_", name), fdef); + } + + if (!reuse_existing_functions || library->Find(name) == nullptr) { + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); } return Status::OK(); } -Status Encapsulator::BuildParallelCheckOp( +Status Encapsulator::Subgraph::BuildParallelCheckOp( const std::unordered_map& node_images, - const Encapsulator::Subgraph& subgraph, Graph* graph_out, - Node** parallel_check_op) { + Graph* graph_out) { // Build an index mapping output positions to node/slot pairs in the // original graph. - std::vector results_by_num(subgraph.results.size()); - for (const auto& entry : subgraph.results) { + std::vector results_by_num(results_.size()); + for (const auto& entry : results_) { results_by_num[entry.second] = entry.first; } @@ -386,22 +903,22 @@ Status Encapsulator::BuildParallelCheckOp( expected_outputs[i] = NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), node_slot.slot, result_dtypes[i]); - actual_outputs[i] = NodeDefBuilder::NodeOut(subgraph.call_node_def.name(), - i, result_dtypes[i]); + actual_outputs[i] = + NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]); } // Assign the parallel check op to a CPU on the same task as the cluster it is // checking. string device, dummy; if (!DeviceNameUtils::SplitDeviceName( - subgraph.call_node_inputs->assigned_device_name(), &device, &dummy)) { + call_node_inputs_->assigned_device_name(), &device, &dummy)) { return errors::InvalidArgument("Could not parse device name"); } strings::StrAppend(&device, "/cpu:0"); NodeDef check_def; TF_RETURN_IF_ERROR( - NodeDefBuilder(graph_out->NewName(strings::StrCat( - subgraph.call_node_def.name(), "_parallel_check")), + NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(), + "_parallel_check")), "ParallelCheck") .Device(device) .Attr("T", result_dtypes) @@ -421,65 +938,548 @@ Status Encapsulator::BuildParallelCheckOp( const NodeSlot& node_slot = results_by_num[i]; graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, i); - graph_out->AddEdge(subgraph.call_node_inputs, i, check_op, num_results + i); + graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i); } - *parallel_check_op = check_op; + call_node_outputs_ = check_op; return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, - Graph* graph_out) { +Status Encapsulator::Subgraph::AddFunctionCallNode( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out) { Status s; + call_node_inputs_ = graph_out->AddNode(call_node_def_, &s); + if (!s.ok()) return s; - // Map from nodes in the input graph to nodes in the output graph. + // Copy the assigned device and the key_annotation over. + call_node_inputs_->set_assigned_device_name(device_); + call_node_outputs_ = call_node_inputs_; + + if (parallel_checking) { + TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out)); + } + return Status::OK(); +} + +Status Encapsulator::Subgraph::AddRecvAtHostNode( + const string& subgraph_name, const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + std::vector dtypes(oc_subgraph->inputs.size(), DT_INVALID); + + for (const auto& input : oc_subgraph->inputs) { + const Node* src_node = input.first.node; + int src_slot = input.first.slot; + int input_index = input.second; + + DataType dtype = src_node->output_type(src_slot); + dtypes[input_index] = dtype; + } + + NodeDef recv_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_recv"), + kRecvAtHostOp); + builder.Attr("dtypes", dtypes); + Status s = builder.Finalize(&recv_def); + if (!s.ok()) return s; + + oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s); + if (!s.ok()) return s; + oc_subgraph->recv_at_host->set_assigned_device_name(device_); + + // Add a control dependency forcing the RecvAtHost to run before the subgraph + // completes. This has no effect on execution order but prevents the + // RecvAtHost being pruned. + TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); + graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_); + + return Status::OK(); +} + +Status Encapsulator::Subgraph::AddSendFromHostNode( + const std::unordered_map& node_images, + const string& subgraph_name, const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + std::vector dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID); + std::vector inputs( + oc_subgraph->outputs_by_src.size()); + + for (const auto& output : oc_subgraph->outputs_by_src) { + const Node* src_node = output.first.node; + Node* src_image = node_images.at(src_node); + int src_slot = output.first.slot; + int output_index = output.second; + + DataType dtype = src_node->output_type(src_slot); + dtypes[output_index] = dtype; + inputs[output_index].Reset(src_image->name(), src_slot, dtype); + } + + NodeDef send_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_send"), + kSendFromHostOp); + builder.Attr("dtypes", dtypes); + builder.Input(inputs); + Status s = builder.Finalize(&send_def); + if (!s.ok()) return s; + + oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s); + if (!s.ok()) return s; + oc_subgraph->send_from_host->set_assigned_device_name(device_); + + // Add a control dependency forcing the SendFromHost to run before the + // subgraph completes. This has no effect on execution order but prevents the + // RecvAtHost being pruned. + TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); + graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_); + + return Status::OK(); +} + +Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( + const string& subgraph_name, + const std::unordered_map& node_images, + Graph* graph_out) { + for (auto& outside_compilation_subgraph_entry : + outside_compilation_subgraphs_) { + const string& oc_name = outside_compilation_subgraph_entry.first; + OutsideCompilationSubgraph& oc_subgraph = + outside_compilation_subgraph_entry.second; + + if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { + TF_RETURN_IF_ERROR( + AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out)); + } + + if (!oc_subgraph.outputs_by_src.empty() || + !oc_subgraph.control_outputs.empty()) { + TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name, + oc_name, &oc_subgraph, graph_out)); + } + } + return Status::OK(); +} + +Status Encapsulator::GetFunctionNameAttr( + Node const* node, string* attr, string* outside_compilation_attr) const { + Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); + if (s.code() == error::Code::NOT_FOUND) { + // Return empty attr if there's no group_attribute. + attr->clear(); + } else { + TF_RETURN_IF_ERROR(s); + } + bool has_group_attr = s.ok(); + s = GetNodeAttr(node->attrs(), outside_compilation_attribute_, + outside_compilation_attr); + if (s.code() == error::Code::NOT_FOUND) { + // Return empty attr if there's no outside_compilation attribute. + outside_compilation_attr->clear(); + } else { + TF_RETURN_IF_ERROR(s); + if (!has_group_attr) { + return errors::InvalidArgument( + "Node ", node->name(), " has ", outside_compilation_attribute_, + " attribute but no ", group_attribute_, " attribute."); + } + } + return Status::OK(); +} + +bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) { + return !func_id.empty() && outside_compilation_id.empty(); +} + +Status Encapsulator::CopySubgraphNodes( + std::unordered_map* node_images) { + for (Node* node : graph_in_->op_nodes()) { + string func_id; + string outside_compilation_id; + TF_RETURN_IF_ERROR( + GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); + if (!IsInSubgraph(func_id, outside_compilation_id)) continue; + + Subgraph& subgraph = subgraphs_[func_id]; + Node* image = subgraph.MakeNodeImage(graph_in_, node); + image->ClearAttr(group_attribute_); + (*node_images)[node] = image; + } + return Status::OK(); +} + +Status Encapsulator::CopySubgraphEdges( + const std::unordered_map& node_images, + std::vector>* src_arg_pairs) { + for (const Edge* edge : graph_in_->edges()) { + string src_func_id; + string src_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, + &src_outside_compilation_id)); + string dst_func_id; + string dst_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, + &dst_outside_compilation_id)); + Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); + Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); + + // Copy edges that are local to a subgraph. + if (IsInSubgraph(src_func_id, src_outside_compilation_id) && + IsInSubgraph(dst_func_id, dst_outside_compilation_id) && + src_func_id == dst_func_id) { + Graph* g = subgraphs_[src_func_id].GetGraph(); + if (edge->IsControlEdge()) { + g->AddControlEdge(src_image, dst_image); + } else { + g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); + } + continue; + } + + // Record 'src' as an output of its subgraph, if applicable. + if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { + if (!edge->IsControlEdge()) { + DataType dtype = edge->src()->output_type(edge->src_output()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported as results: " + "tensor ", + edge->src()->name(), ":", edge->src_output()); + } + } + + Subgraph& src_subgraph = subgraphs_[src_func_id]; + if (src_func_id == dst_func_id) { + // src is in the subgraph and dst is outside_compilation in the same + // subgraph. + src_subgraph.RecordOutsideCompilationInputOrControl( + dst_outside_compilation_id, edge); + } else { + // Ignore control edges leaving the subgraph. We will lift them onto the + // enclosing call operators in BuildOutputGraph(). + if (!edge->IsControlEdge()) { + TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images)); + } + } + } + + // Record 'dst' as an input of its subgraph, if applicable. + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { + // Look at the type of the destination not the source, since Ref output + // Tensors can be automatically cast to non-Ref Tensors at the + // destination. + if (!edge->IsControlEdge()) { + DataType dtype = edge->dst()->input_type(edge->dst_input()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported as args: " + "tensor ", + edge->src()->name(), ":", edge->src_output()); + } + } + + Subgraph& dst_subgraph = subgraphs_[dst_func_id]; + if (src_func_id == dst_func_id) { + // dst is in the subgraph and src is outside_compilation in the same + // subgraph. + dst_subgraph.RecordOutsideCompilationOutputOrControl( + src_outside_compilation_id, edge); + } else { + // Ignore control edges entering the subgraph. We will lift them onto + // the enclosing call operators in BuildOutputGraph(). + if (!edge->IsControlEdge()) { + TF_RETURN_IF_ERROR( + dst_subgraph.RecordArg(edge, node_images, src_arg_pairs)); + } + } + } + } + return Status::OK(); +} + +Status Encapsulator::SplitIntoSubgraphs() { + Status s; + + // Map from input graph nodes to subgraph nodes. std::unordered_map node_images; - // Copy all unmarked nodes to the output graph. + // Each entry of src_arg_pairs is a pair whose first element is a node in the + // original graph that has an output edge in the subgraph, and whose second + // element is the arg node in the subgraph that it sends to. The vector will + // be filled in below in AddArgs. + std::vector> src_arg_pairs; + + TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images)); + TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs)); + + // For each subgraph, add the nodes that deal with inputs and outputs its + // nested outside_compilation subgraphs. These could not be added earlier + // during CopySubgraphEdges since we need to discover all the types of the + // inputs and outputs for an outside_compilation subgraph before creating a + // single input and output node for it. + for (auto& entry : subgraphs_) { + Subgraph& subgraph = entry.second; + TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images)); + TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images)); + } + + MarkGuaranteedConstants(*graph_in_, src_arg_pairs); + + for (auto& entry : subgraphs_) { + Subgraph& subgraph = entry.second; + FixupSourceAndSinkEdges(subgraph.GetGraph()); + } + + return s; +} + +Status Encapsulator::BuildFunctionDefs( + const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, + FunctionLibraryDefinition* library) { + for (auto& subgraph_entry : subgraphs_) { + string name = subgraph_entry.first; + Subgraph& subgraph = subgraph_entry.second; + TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef( + name, rewrite_subgraph_fn, reuse_existing_functions, library)); + } + return Status::OK(); +} + +Status Encapsulator::CopyNodesToOutputGraph( + bool parallel_checking, Graph* graph_out, + std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { - string func_id = GetFunctionNameAttr(node); + string func_id; + string outside_compilation_id; + TF_RETURN_IF_ERROR( + GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); // Don't copy nodes that going to be encapsulated, unless parallel checking // is enabled. - if (!func_id.empty() && !parallel_checking) continue; + if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking) + continue; Node* image = graph_out->CopyNode(node); - node_images[node] = image; + if (!outside_compilation_id.empty()) { + if (parallel_checking) { + return errors::InvalidArgument( + "Parallel checking is not supported when outside_compilation " + "clusters are present."); + } + image->ClearAttr(group_attribute_); + image->ClearAttr(outside_compilation_attribute_); + } + (*node_images)[node] = image; + } + (*node_images)[graph_in_->source_node()] = graph_out->source_node(); + (*node_images)[graph_in_->sink_node()] = graph_out->sink_node(); + return Status::OK(); +} + +Status Encapsulator::AddFunctionCallNodes( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out) { + for (auto& subgraph_entry : subgraphs_) { + TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode( + node_images, parallel_checking, graph_out)); } - node_images[graph_in_->source_node()] = graph_out->source_node(); - node_images[graph_in_->sink_node()] = graph_out->sink_node(); + return Status::OK(); +} - // Add function call nodes for each subgraph. +Status Encapsulator::AddOutsideCompilationHostIONodes( + const std::unordered_map& node_images, + Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { + const string& subgraph_name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; + TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes( + subgraph_name, node_images, graph_out)); + } + return Status::OK(); +} - subgraph.call_node_inputs = graph_out->AddNode(subgraph.call_node_def, &s); - if (!s.ok()) return s; +Status Encapsulator::FindOutputImageOfEdgeSrc( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + const Node* original_src_node, Node** src_image) { + if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { + if (dst_func_id == src_func_id) { + // The edge is from a subgraph to an outside_compilation cluster in the + // same subgraph so use the appropriate _RecvAtHost node in the output + // graph. + TF_RET_CHECK(!dst_outside_compilation_id.empty()); + *src_image = subgraphs_.at(src_func_id) + .GetRecvAtHostNode(dst_outside_compilation_id); + } else { + // The edge is from a subgraph to a regular node in the output graph so + // use the subgraph's call node output. + *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + } + } else { + // The source of the edge is in the output graph so use the node image in + // the output graph. + *src_image = node_images.at(original_src_node); + } + return Status::OK(); +} + +int Encapsulator::FindOutputSlotOfEdgeSrc( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const Edge* edge) { + if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { + const Subgraph& src_subgraph = subgraphs_.at(src_func_id); + if (src_func_id == dst_func_id) { + // 'src' is in a subgraph and 'dst' is outside_compilation in the same + // subgraph. Use the corresponding _RecvAtHost output instead. + return src_subgraph.GetRecvAtHostSlot(dst_outside_compilation_id, edge); + } else { + // 'src' is in a subgraph and 'dst' is a regular node in the output + // graph. Use the corresponding call output instead. + return src_subgraph.GetResultIndexForEdge(edge); + } + } else { + // The source of the edge is in the output graph so use the regular edge + // slot. + return edge->src_output(); + } +} + +Status Encapsulator::FindOutputImageOfEdgeDst( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + const Node* original_dst_node, Node** dst_image) { + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { + if (src_func_id == dst_func_id) { + // The edge is to a subgraph from an outside_compilation cluster in the + // same subgraph so use the appropriate _SendFromHost node in the output + // graph. + TF_RET_CHECK(!src_outside_compilation_id.empty()); + *dst_image = subgraphs_.at(dst_func_id) + .GetSendFromHostNode(src_outside_compilation_id); + } else { + // The edge is to a subgraph from a regular node in the output graph so + // use the subgraph's call node input. + *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + } + } else { + // The destination of the edge is in the output graph so use the node image + // in the output graph. + *dst_image = node_images.at(original_dst_node); + } + return Status::OK(); +} - // Copy the assigned device and the key_annotation over. - subgraph.call_node_inputs->set_assigned_device_name(subgraph.device); - subgraph.call_node_outputs = subgraph.call_node_inputs; +int Encapsulator::FindOutputSlotOfEdgeDst( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const Edge* edge) { + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { + const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); + if (dst_func_id == src_func_id) { + // 'dst' is in a subgraph and 'src' is outside_compilation in the same + // subgraph. Use the corresponding _SendFromHost input instead. + return dst_subgraph.GetSendFromHostSlot(src_outside_compilation_id, edge); + } else { + // 'dst' is in a subgraph and 'src' is a regular node in the output + // graph. Use the corresponding call input instead. + return dst_subgraph.GetArgIndexForEdge(edge); + } + } else { + // The destination of the edge is in the output graph so use the regular + // edge slot. + return edge->dst_input(); + } +} +Status Encapsulator::CopyEdgeToOutputGraph( + const Edge* edge, const string& src_func_id, + const string& src_outside_compilation_id, const string& dst_func_id, + const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out, + std::unordered_set, NodeSlot::PairHasher>* + edges_added) { + Node* src_image; + TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( + src_func_id, src_outside_compilation_id, dst_func_id, + dst_outside_compilation_id, node_images, edge->src(), &src_image)); + Node* dst_image; + TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst( + src_func_id, src_outside_compilation_id, dst_func_id, + dst_outside_compilation_id, node_images, edge->dst(), &dst_image)); + + // If this is a control edge then copy it and return. Lift control edges onto + // the enclosing call operator. + if (edge->IsControlEdge()) { + // Add the control edge, if we have not already added it, using the images + // determined above (potentially call operators or RecvAtHost/SendFromHost). + if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) + .second) { + graph_out->AddControlEdge(src_image, dst_image); + } + + // If parallel checking is enabled, also add a control edge to the + // corresponding parallel check op. if (parallel_checking) { - TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, subgraph, graph_out, - &subgraph.call_node_outputs)); + graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); } + return Status::OK(); } + int src_output = + FindOutputSlotOfEdgeSrc(src_func_id, src_outside_compilation_id, + dst_func_id, dst_outside_compilation_id, edge); + + int dst_input = + FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, + dst_func_id, dst_outside_compilation_id, edge); + + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) && + parallel_checking) { + // If we are parallel checking, also feed the tensor as an input to the + // corresponding parallel check subgraph. + graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), + edge->dst_input()); + } + + // Add the edge, if we have not already added it. + if (edges_added + ->emplace(NodeSlot(src_image, src_output), + NodeSlot(dst_image, dst_input)) + .second) { + graph_out->AddEdge(src_image, src_output, dst_image, dst_input); + } + return Status::OK(); +} + +Status Encapsulator::AddEdgesToOutputGraph( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. std::unordered_set, NodeSlot::PairHasher> edges_added; - // Add edges to the graph_out graph. for (const Edge* edge : graph_in_->edges()) { - string src_func_id = GetFunctionNameAttr(edge->src()); - string dst_func_id = GetFunctionNameAttr(edge->dst()); + string src_func_id; + string src_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, + &src_outside_compilation_id)); + string dst_func_id; + string dst_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, + &dst_outside_compilation_id)); // Ignore edges that are strictly contained within one subgraph, unless // we are constructing parallel check graphs. - if (!src_func_id.empty() && src_func_id == dst_func_id) { + if (IsInSubgraph(src_func_id, src_outside_compilation_id) && + IsInSubgraph(dst_func_id, dst_outside_compilation_id) && + src_func_id == dst_func_id) { if (parallel_checking) { Node* src_image = node_images.at(edge->src()); Node* dst_image = node_images.at(edge->dst()); @@ -493,89 +1493,62 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, continue; } - // We have an edge that crosses a cluster boundary. - Node* src_image = src_func_id.empty() - ? node_images.at(edge->src()) - : subgraphs_.at(src_func_id).call_node_outputs; - Node* dst_image = dst_func_id.empty() - ? node_images.at(edge->dst()) - : subgraphs_.at(dst_func_id).call_node_inputs; - - // Copy control edges. Lift control edges onto the enclosing call operator. - if (edge->IsControlEdge()) { - // Add the control edge, if we have not already added it. - if (edges_added.emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) - .second) { - graph_out->AddControlEdge(src_image, dst_image); - } - - // If parallel checking is enabled, also add a control edge to the - // corresponding parallel check op. - if (parallel_checking) { - graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); - } - continue; - } + // We have an edge that crosses a cluster boundary or is entirely within the + // unclustered graph. + TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( + edge, src_func_id, src_outside_compilation_id, dst_func_id, + dst_outside_compilation_id, node_images, parallel_checking, graph_out, + &edges_added)); + } - int src_output = edge->src_output(); - if (!src_func_id.empty()) { - // 'src' is in a subgraph. Use the corresponding call output instead. - const Subgraph& src_subgraph = subgraphs_.at(src_func_id); - src_output = - src_subgraph.results.at(NodeSlot(edge->src(), edge->src_output())); - } + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + subgraph.ConnectSequencerToOutputs(graph_out); + } - int dst_input = edge->dst_input(); + return Status::OK(); +} - if (!dst_func_id.empty()) { - // 'dst' is in a subgraph. Use the corresponding call input instead. - const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); - dst_input = - dst_subgraph.args_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); +Status Encapsulator::BuildOutputGraph(bool parallel_checking, + Graph* graph_out) { + // Map from nodes in the input graph to nodes in the output graph. + std::unordered_map node_images; - // If we are parallel checking, also feed the tensor as an input to the - // corresponding parallel check subgraph. - if (parallel_checking) { - graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), - edge->dst_input()); - } - } - // Add the edge, if we have not already added it. - if (edges_added - .emplace(NodeSlot(src_image, src_output), - NodeSlot(dst_image, dst_input)) - .second) { - graph_out->AddEdge(src_image, src_output, dst_image, dst_input); - } - } + TF_RETURN_IF_ERROR( + CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); + TF_RETURN_IF_ERROR( + AddFunctionCallNodes(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); + TF_RETURN_IF_ERROR( + AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); - return s; + return Status::OK(); } } // anonymous namespace Status EncapsulateSubgraphsInFunctions( - string group_attribute, const Graph& graph_in, - const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking, - bool reuse_existing_functions, std::unique_ptr* graph_out, - FunctionLibraryDefinition* library) { + string group_attribute, string outside_compilation_attribute, + const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + bool parallel_checking, bool reuse_existing_functions, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { Status s; - Encapsulator encapsulator(std::move(group_attribute), &graph_in); - s = encapsulator.SplitIntoSubgraphs(); - if (!s.ok()) return s; + Encapsulator encapsulator(std::move(group_attribute), + std::move(outside_compilation_attribute), + &graph_in); + TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); - s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, - reuse_existing_functions, library); - if (!s.ok()) return s; + TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( + rewrite_subgraph_fn, reuse_existing_functions, library)); std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); - s = encapsulator.BuildOutputGraph(parallel_checking, out.get()); - if (!s.ok()) return s; + TF_RETURN_IF_ERROR( + encapsulator.BuildOutputGraph(parallel_checking, out.get())); *graph_out = std::move(out); - return s; + return Status::OK(); } // Finds the types of the _Arg nodes, indexed by position. @@ -690,9 +1663,9 @@ Status EncapsulateSubgraphsPass::Run( }; TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, **options.graph, rewrite_subgraph, - flags->tf_xla_parallel_checking, /*reuse_existing_functions=*/false, - &graph_out, library)); + kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, + rewrite_subgraph, flags->tf_xla_parallel_checking, + /*reuse_existing_functions=*/false, &graph_out, library)); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index b0987f76c91ed48df52fab303ea6052ebd8fd336..34be4409a381197d2191e083727aa8d48ab8cd63 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -48,6 +48,16 @@ typedef std::function* graph_out, - FunctionLibraryDefinition* library); + string group_attribute, string outside_compilation_attribute, + const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + bool parallel_checking, bool reuse_existing_functions, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via _XlaLaunch operators. diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 4a1dbaf05dc7824835f3567c6abcf48222720230..b100861d5e9c04a8f9d32d486e0ee7252b79c62b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -36,7 +36,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, if (diff) { *diff = strings::StrCat("Definition mismatch for function ", a.signature().name(), ", expected:\n", - a.DebugString()); + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -82,6 +82,24 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, << diff << "\nActual: " << actual.DebugString(); \ } while (false) +// TODO(misard): remove these fake registrations once there are real Ops to be +// compiled. +REGISTER_OP("_XlaSendToHost") + .Input("input: dtypes") + .Attr("dtypes: list(type) >= 0"); + +REGISTER_OP("_XlaRecvFromHost") + .Output("output: dtypes") + .Attr("dtypes: list(type) >= 0"); + +REGISTER_OP("_XlaSendFromHost") + .Input("input: dtypes") + .Attr("dtypes: list(type) >= 0"); + +REGISTER_OP("_XlaRecvAtHost") + .Output("output: dtypes") + .Attr("dtypes: list(type) >= 0"); + REGISTER_OP("InputTest").Output("o: float"); REGISTER_OP("UnaryTest").Input("a: float").Output("o: float"); @@ -98,10 +116,32 @@ REGISTER_OP("AddNLikeTest") .SetIsCommutative() .SetIsAggregate(); +Node* NoOp(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("NoOp", opts); +} + Node* Input(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTest", opts); } +Node* RecvAtHost(const gtl::ArraySlice& dtypes, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), + "_XlaRecvAtHost", opts.op_registry()); + return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); +} + +Node* SendFromHost(const std::vector& inputs, + const gtl::ArraySlice& dtypes, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), + "_XlaSendFromHost", opts.op_registry()); + node_builder.Input(inputs); + return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder); +} + Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) { return ops::UnaryOp("UnaryTest", std::move(a), opts); } @@ -145,7 +185,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { if (!s.ok()) return s; std::unique_ptr graph_out; - s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph, + s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, /*reuse_existing_functions=*/false, @@ -178,6 +218,7 @@ TEST(EncapsulateSubgraphsTest, NoFunctions) { FunctionDefLibrary library_out = library_in; TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out)); + // If there are no marked nodes, funcification should be a no-op. TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out); } @@ -230,7 +271,6 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } - // If there are no marked nodes, funcification should be a no-op. TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } @@ -342,9 +382,9 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph, - &library)); + "_cluster", "_outside", graph_before_encapsulation, + /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; EXPECT_EQ(expected_nodes, GraphNodes(*graph)); @@ -374,9 +414,9 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph, - &library)); + "_cluster", "_outside", graph_before_encapsulation, + /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true, + /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = { "add1", "add2", "cluster1", "cluster1_parallel_check/_0", @@ -398,5 +438,782 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } +const Node* FindNodeByName(const Graph& graph, const string& name) { + for (const Node* node : graph.nodes()) { + if (node->name() == name) return node; + } + return nullptr; +} + +bool HasGuaranteeConstAttr(const Node& n) { + bool is_guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", + &is_guaranteed_constant) + .ok()) { + return false; + } + return is_guaranteed_constant; +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2); + add1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", "_outside", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + EXPECT_EQ(2, guaranteed_consts); +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto const_guarantee_x2 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2); + auto const_guarantee_add1 = ops::Add(root.WithOpName("const_guarantee_add1"), + const_guarantee_x1, const_guarantee_x2); + auto add2 = ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2); + auto mul1 = ops::Mul(root.WithOpName("mul1"), const_guarantee_add1, add2); + mul1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", "_outside", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const + // and another non-const, so overall non-const. + EXPECT_EQ(1, guaranteed_consts); +} + +// Test with one function to transform and one outside_compilation cluster. +TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + *library.add_function() = test::function::XTimesTwo(); + + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + // Give nodes 'c' and 'd' names that collide after lowercasing. + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = Binary(b, c, + b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(a, f, b1.opts().WithName("G").WithControlInput(e)); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = test::function::XTimesTwo(); + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {}, + {"outside_compilation_O1_recv"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"C:o:0", "c:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {"c"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().FinalizeBuilder(&node_builder); + + Node* recv = + RecvAtHost({DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + b2.opts().WithName("E").WithControlInputs({recv, b})); + Node* send = SendFromHost({e}, {DT_FLOAT}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* s = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one function to transform and two outside_compilation clusters. +TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Node* g = Binary(e, f, + b1.opts() + .WithName("G") + .WithControlInputs({e, f}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* h = Binary(d, e, + b1.opts() + .WithName("H") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1")); + Binary(g, i, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, + {{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {}, + {"outside_compilation_O1_recv"}}, + {{"outside_compilation_O2_send"}, + "_XlaSendToHost", + {"D:o:0", "F:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {"F"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"C:o:0", "D:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {"D"}}, + {{"outside_compilation_O2_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O2_send"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"i_0_retval", "I:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().FinalizeBuilder(&node_builder); + + Node* recv1 = + RecvAtHost({DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + b2.opts().WithName("E").WithControlInputs({recv1, b})); + Node* send1 = SendFromHost({e}, {DT_FLOAT}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* recv2 = + RecvAtHost({DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O2_recv")); + Node* g = Binary(e, ops::NodeOut(recv2, 1), + b2.opts().WithName("G").WithControlInputs({recv2, e})); + Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); + Node* send2 = SendFromHost( + {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send")); + + Node* s = NoOp(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2})); + + Binary(g, call, b2.opts().WithName("J").WithControlInput(s)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two functions to transform, each with one outside_compilation +// cluster. +TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Node* g = Binary(e, f, + b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr( + "_encapsulate", "F2")); + Node* h = Binary(d, g, + b1.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1")); + Node* i = + Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); + Binary(g, i, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, + {"f_0_retval:float", "d_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_recv:output:0"}, + {}, + {"outside_compilation_O1_recv"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"C:o:0", "D:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}}, + {"D"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); + + *library_expected.add_function() = FunctionDefHelper::Create( + "F2", {"e_0_arg:float", "f_0_arg:float"}, + {"g_0_retval:float", "i_0_retval:float"}, {}, + { + {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, + {{"I"}, + "BinaryTest", + {"f_0_arg", "outside_compilation_O1_recv:output:0"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"G:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = + RecvAtHost({DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + b2.opts().WithName("E").WithControlInputs({recv1, b})); + Node* send1 = SendFromHost({e}, {DT_FLOAT}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + + Node* recv2 = RecvAtHost( + {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* h = Binary(ops::NodeOut(call1, 1), recv2, + b2.opts().WithName("H").WithControlInput(s1)); + Node* send2 = SendFromHost( + {h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send")); + + NodeBuilder node_builder2("F2", "F2", lib_def.get()); + node_builder2.Input(e).Input(call1); + Node* call2 = b2.opts() + .WithControlInputs({s1, e, call1}) + .FinalizeBuilder(&node_builder2); + Node* s2 = NoOp( + b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2})); + Binary(call2, ops::NodeOut(call2, 1), + b2.opts().WithName("J").WithControlInput(s2)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no inputs from the +// compiled subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = + Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Unary(f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"D:o:0", "outside_compilation_O1_recv:output:0"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* e = Unary(a, b2.opts().WithName("E")); + Node* send1 = SendFromHost( + {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1)); + + Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no data inputs but has a +// control input from the compiled subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithControlInput(d) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = + Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Unary(f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"D:o:0", "outside_compilation_O1_recv:output:0"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {}, + {{"dtypes", gtl::ArraySlice({})}}, + {"D"}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}, + {"outside_compilation_O1_send"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = + RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); + Node* send1 = SendFromHost( + {e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + + Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no outputs from the +// compiled subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Binary(e, f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"D:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = RecvAtHost( + {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Unary(recv1, b2.opts().WithName("E")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1)); + + Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no data outputs but has a +// control output to the compiled subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(e, f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}}, + {{"outside_compilation_O1_send"}, + "_XlaSendToHost", + {"D:o:0"}, + {{"dtypes", gtl::ArraySlice({DT_FLOAT})}}}, + {{"outside_compilation_O1_recv"}, + "_XlaRecvFromHost", + {}, + {{"dtypes", gtl::ArraySlice({})}}, + {"outside_compilation_O1_send"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = RecvAtHost( + {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Unary(recv1, b2.opts().WithName("E")); + Node* send1 = SendFromHost({}, {}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + + Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no outputs from the +// compiled subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Binary(e, f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* e = Unary(a, b2.opts().WithName("E")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + + Binary(e, call1, b2.opts().WithName("G")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 459a582e157f5ddc63997ca93e7c0294293517d3..9bea5663319c8a25249fdc265cee0191556a7c04 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -16,7 +16,6 @@ cc_library( "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index e481796d9e626fc8cdf36687ad110b0a8a788be0..4842877d9af332bdaa4a142867dde89ba66bd9a2 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -103,7 +102,6 @@ xla::StatusOr XlaAllocator::Allocate( } void* data = reinterpret_cast(const_cast(t.tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = t; return gpu::DeviceMemoryBase(data, size); } @@ -111,7 +109,6 @@ xla::StatusOr XlaAllocator::Allocate( Status XlaAllocator::RegisterArgument(const Tensor* t) { void* data = reinterpret_cast(const_cast(t->tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = *t; return Status::OK(); } @@ -260,14 +257,15 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, - variables, ctx, &kernel, &executable)); + variables, ctx, &kernel, &executable, + /*compile_options=*/nullptr)); VLOG(1) << "Executing XLA Computation..."; // Builds an XLA allocator for the device. XlaAllocator xla_allocator(client->platform(), ctx); - XlaLocalRuntimeContext local_runtime_context; std::unique_ptr output; // Build xla::ShapedBuffers that point directly to the Tensor buffers. @@ -291,27 +289,22 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( const_cast(t->tensor_data().data()), t->tensor_data().size()); - arg_buffers[i] = - xla::ShapedBuffer::MakeArrayShapedBuffer( - shape, client->platform(), client->default_device_ordinal(), dmem) - .ConsumeValueOrDie(); + const xla::Shape on_device_shape = + client->backend().transfer_manager()->HostShapeToDeviceShape(shape); + CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) + << "On-device shape " + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) + << " not the same as on-host shape " + << xla::ShapeUtil::HumanStringWithLayout(shape); + arg_buffers[i] = xla::MakeUnique( + /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(), + client->default_device_ordinal()); + arg_buffers[i]->set_buffer(dmem, /*index=*/{}); arg_ptrs[i] = arg_buffers[i].get(); OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t)); } - // Make the final parameter point at local_runtime_context. - if (kernel->requires_runtime_context) { - gpu::DeviceMemoryBase local_runtime_context_dmem( - &local_runtime_context, sizeof(local_runtime_context)); - arg_buffers.push_back( - xla::ShapedBuffer::MakeArrayShapedBuffer( - xla::ShapeUtil::MakeOpaqueShape(), client->platform(), - client->default_device_ordinal(), local_runtime_context_dmem) - .ConsumeValueOrDie()); - arg_ptrs.push_back(arg_buffers.back().get()); - } - // Execute the computation. VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; @@ -323,19 +316,13 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { auto run_result = executable->Run(arg_ptrs, run_options); OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - if (local_runtime_context.error) { - ctx->CtxFailure(errors::InvalidArgument("Compiled kernel returned error: ", - local_runtime_context.error_msg)); - return; - } - output = run_result.ConsumeValueOrDie()->release(); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; // Computation output should always be a tuple. if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output->shape().DebugString(); + VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 74c9791f5eaf1fbc43b152520df496a3b552af18..79b02baba83cb47f4f2f16544ad711a4b6937d90 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -41,6 +41,7 @@ limitations under the License. namespace tensorflow { const char* const kXlaClusterAttr = "_XlaCluster"; +const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; namespace { @@ -172,10 +173,15 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } +struct NodeCompare { + bool operator()(const Node* a, const Node* b) { return a->id() < b->id(); } +}; +using OrderedNodeSet = std::set; + Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, - std::unordered_set* candidates) { + OrderedNodeSet* candidates) { OptimizerOptions opts; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -210,6 +216,13 @@ Status FindCompilationCandidates( !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { continue; } + // _Retval nodes in a top-level function represent fetches. + // Do not compile them. + if (node->type_string() == "_Retval") { + VLOG(2) << "Compilation rejected node: return value " << node->name() + << ": " << node->type_string(); + continue; + } candidates->insert(node); } return Status::OK(); @@ -347,7 +360,7 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); - std::unordered_set compilation_candidates; + OrderedNodeSet compilation_candidates; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? options.session_options->env diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index f91695800f585f37b72173d5e582c38b1154b69b..e9acbfb19e42cb43cb0b986c438a569de29b2ebc 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -28,6 +28,10 @@ namespace tensorflow { // encapsulate subgraphs pass. extern const char* const kXlaClusterAttr; +// The attribute that marks nodes in a cluster to be placed outside the xla +// compilation by the encapsulate subgraphs pass. +extern const char* const kXlaOutsideCompilationAttr; + // Pass that marks a subset of operators in the graph with attribute // _XlaCluster so they are compiled by the EncapsulateSubgraphsPass. class MarkForCompilationPass : public GraphOptimizationPass { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index b3d258aea177fbefa4bae51d8156da2ff86c9032..454f0aeae98d7afd51f12b2cfb1810de275a57f7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -525,5 +525,32 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { "+-- c\n")); } +TEST(XlaCompilationTest, Retval) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::UnaryOp("_Retval", b, + builder.opts() + .WithName("R") + .WithAttr("T", DT_FLOAT) + .WithAttr("index", 0)); + + TF_EXPECT_OK(builder.ToGraph(graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + + EXPECT_EQ(2, clusters.size()); + EXPECT_TRUE(clusters.find("R") == clusters.cend()); + EXPECT_EQ(clusters["A"], clusters["B"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index bc2eccd2779b9ff68ae2121f7bc53d6f74aec3e3..bfff52c55a7d5a4490224347019db9b3333f7e2e 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -214,17 +214,12 @@ Status XlaCompilationCache::BuildExecutable( const XlaCompiler::CompilationResult& result, std::unique_ptr* executable) { VLOG(2) << "Compiling to local executable"; - xla::Shape opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); std::vector argument_layouts( result.xla_input_shapes.size()); for (int i = 0; i < result.xla_input_shapes.size(); ++i) { argument_layouts[i] = &result.xla_input_shapes[i]; } - if (result.requires_runtime_context) { - // The final arg is the XlaLocalRuntimeContext*. - argument_layouts.push_back(&opaque_shape); - } xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client_->default_device_ordinal()); build_options.set_result_layout(result.xla_output_shape); @@ -243,7 +238,8 @@ Status XlaCompilationCache::Compile( int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable) { + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -302,9 +298,9 @@ Status XlaCompilationCache::Compile( XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = - compiler.CompileFunction(XlaCompiler::CompileOptions(), function, args, - &entry->compilation_result); + entry->compilation_status = compiler.CompileFunction( + compile_options ? *compile_options : XlaCompiler::CompileOptions(), + function, args, &entry->compilation_result); } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index c3a8f68a157a2d34d4a6716c9951b2b698aead79..0858020716fcf4763e42dc0699ad22cfda756942 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -66,7 +66,8 @@ class XlaCompilationCache : public ResourceBase { const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable); + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index fed2c92d763c33aad3c5b3f07c1f33364c797793..c936222f32056e92efced82d5adb3a96c8041a17 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -71,12 +71,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, void* dst_ptr = DMAHelper::base(device_tensor); se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes); - Status status = Status::OK(); + Status status; stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. - if (!stream_->BlockHostUntilDone()) { + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p", stream_); + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); } done(status); @@ -105,12 +107,14 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, se::DeviceMemoryBase dev_src_ptr(src_ptr, total_bytes); void* dst_ptr = DMAHelper::base(cpu_tensor); - Status status = Status::OK(); + Status status; stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. - if (!stream_->BlockHostUntilDone()) { + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p", stream_); + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); } done(status); diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD index c1edf2448c54ffddd7b70dcdfb1609080ca81b65..da4bc44c7a75c9f8faf16c537a17a1f2d16d5d61 100644 --- a/tensorflow/compiler/plugin/BUILD +++ b/tensorflow/compiler/plugin/BUILD @@ -41,6 +41,15 @@ cc_library( ], ) +# This target is added purely for the purpose of ensuring that `:xla_device` is +# always publicly visible to external XLA backend/plugin developers. +cc_library( + name = "plugin_device", + deps = [ + "//tensorflow/compiler/jit:xla_device", + ], +) + #----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 6cad2b0824d86a9549cb77518448a7e4eb781bef..f7c6cd293a8a4788bd73cc42c5c61e60d4a2c110 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -240,6 +240,23 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "fft_test", + size = "medium", + srcs = ["fft_test.py"], + shard_count = 3, + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/contrib/signal:signal_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:spectral_ops", + ], +) + tf_xla_py_test( name = "slice_ops_test", size = "small", @@ -279,6 +296,22 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "image_ops_test", + size = "small", + srcs = ["image_ops_test.py"], + tags = [ + "optonly", # Times out frequently in fastbuild mode. + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:image_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -367,7 +400,14 @@ tf_xla_py_test( size = "small", srcs = ["random_ops_test.py"], # TODO(b/31361304): enable RNG ops on GPU when parallelized. - disabled_backends = ["gpu"], + disabled_backends = [ + "gpu", + ], + tags = [ + "manual", + "no_oss", + "notap", + ], deps = [ ":xla_test", "//tensorflow/python:framework_for_generated_wrappers", @@ -416,6 +456,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "scan_ops_test", + size = "small", + srcs = ["scan_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "segment_reduction_ops_test", size = "medium", diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 654dc15e86b21c7742d49281d53c1a75e6a45d3b..65706b35d616eb4dce94f0a7056a1604a97ff4c1 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -94,14 +94,12 @@ class BinaryOpsTest(XLATestCase): dtype(4), expected=np.array([[16], [81]], dtype=dtype)) - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._testBinary( - math_ops.atan2, - np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), - np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), - expected=np.array( - [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) + self._testBinary( + math_ops.atan2, + np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), + np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), + expected=np.array( + [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) self._testBinary( gen_math_ops._reciprocal_grad, @@ -388,30 +386,28 @@ class BinaryOpsTest(XLATestCase): ], dtype=dtype)) - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._testBinary( - math_ops.pow, - dtype(3 + 2j), - dtype(4 - 5j), - expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) - self._testBinary( # empty rhs - math_ops.pow, - np.array([1 + 2j, 2 - 3j], dtype=dtype), - np.zeros(shape=[0, 2], dtype=dtype), - expected=np.zeros(shape=[0, 2], dtype=dtype)) - self._testBinary( # to zero power - math_ops.pow, - np.array([1 + 2j, 2 - 3j], dtype=dtype), - np.zeros(shape=[1, 2], dtype=dtype), - expected=np.ones(shape=[1, 2], dtype=dtype)) - lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) - rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) - scalar = dtype(2 + 2j) - self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) - self._testBinary( - math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) - self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) + self._testBinary( + math_ops.pow, + dtype(3 + 2j), + dtype(4 - 5j), + expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) + self._testBinary( # empty rhs + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[0, 2], dtype=dtype), + expected=np.zeros(shape=[0, 2], dtype=dtype)) + self._testBinary( # to zero power + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[1, 2], dtype=dtype), + expected=np.ones(shape=[1, 2], dtype=dtype)) + lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) + rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) + scalar = dtype(2 + 2j) + self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) + self._testBinary( + math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) + self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) @@ -421,9 +417,8 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) - if atan2_supported: - self._testBinary( - gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) + self._testBinary( + gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) @@ -547,7 +542,7 @@ class BinaryOpsTest(XLATestCase): self._testDivision(dtype) def testFloatDivision(self): - for dtype in self.float_types + self.complex_types: + for dtype in self.float_types | self.complex_types: self._testDivision(dtype) def _testRemainder(self, dtype): diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 5e06f9a72401935b9681c35a164b51f50a8538ae..035cdea1786d39f3d21bb63be5c8ccffe1608bdf 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -35,6 +35,9 @@ from tensorflow.python.platform import googletest class CategoricalTest(XLATestCase): """Test cases for random-number generating operators.""" + def output_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + def _chi2(self, expected, actual): """Returns Chi2 GOF statistic.""" actual = np.asarray(actual) @@ -55,7 +58,8 @@ class CategoricalTest(XLATestCase): """ with self.test_session() as sess, self.test_scope(): random_seed.set_random_seed(1618) - op = random_ops.multinomial(logits, num_samples) + op = random_ops.multinomial(logits, num_samples, + output_dtype=dtypes.int32) d = sess.run(op) batch_size, num_classes = logits.shape @@ -73,11 +77,11 @@ class CategoricalTest(XLATestCase): return freqs_mat - def _testRngIsNotConstant(self, rng, dtype): + def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. with self.test_session() as sess: with self.test_scope(): - x = rng(dtype) + x = rng(dtype, output_dtype) # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. @@ -92,21 +96,25 @@ class CategoricalTest(XLATestCase): (not np.array_equal(y, w))) def testCategoricalIsNotConstant(self): - def rng(unused_dtype): - return random_ops.multinomial([[1., 1., 1.]], 10) + def rng(dtype, output_dtype): + return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10, + output_dtype=output_dtype) - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + dtype = np.float32 + for output_dtype in self.output_dtypes(): + self._testRngIsNotConstant(rng, dtype, output_dtype) def testCategoricalIsInRange(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session() as sess: - with self.test_scope(): - x = random_ops.multinomial( - array_ops.ones(shape=[1, 20], dtype=dtype), 1000) - y = sess.run(x) - self.assertTrue((y >= 0).sum() == 1000) - self.assertTrue((y < 20).sum() == 1000) + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.test_session() as sess: + with self.test_scope(): + x = random_ops.multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), 1000, + output_dtype=output_dtype) + y = sess.run(x) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) def testSamplingCorrectness(self): np.random.seed(1618) # Make it reproducible. diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index 0d617eb37c5d92c87abb0f996b731112257a2b80..62577b70ce96e220d79978f01614b2d9a3647680 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -34,7 +34,13 @@ from tensorflow.python.platform import googletest class Conv2DTest(XLATestCase): - def _VerifyValues(self, input_sizes, filter_sizes, stride, padding, expected): + def _VerifyValues(self, + input_sizes=None, + filter_sizes=None, + strides=None, + dilations=None, + padding=None, + expected=None): """Tests that tf.nn.conv2d produces the expected value. Args: @@ -42,7 +48,8 @@ class Conv2DTest(XLATestCase): [batch, input_rows, input_cols, input_depth]. filter_sizes: Filter tensor dimensions in [kernel_rows, kernel_cols, input_depth, output_depth]. - stride: Stride. + strides: Strides. + dilations: RHS dilations. padding: Padding type. expected: Expected output. """ @@ -50,73 +57,136 @@ class Conv2DTest(XLATestCase): total_size_2 = np.prod(filter_sizes) x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes) x2 = np.arange(1, total_size_2 + 1, dtype=np.float32).reshape(filter_sizes) - strides = [1, stride, stride, 1] + strides = [1] + strides + [1] + if dilations is None: + dilations = [1, 1] + dilations = [1] + dilations + [1] with self.test_session() as sess: + t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) + t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) with self.test_scope(): - t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) - t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) out = nn_ops.conv2d( - t1, t2, strides=strides, padding=padding, data_format="NHWC") + t1, + t2, + strides=strides, + padding=padding, + data_format="NHWC", + dilations=dilations) value = sess.run(out, {t1: x1, t2: x2}) - self.assertArrayNear(expected, np.ravel(value), 1e-3) + self.assertAllClose(expected, value, 1e-3) def testConv2D1x1Filter(self): - expected_output = [ + expected_output = np.reshape([ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, 204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0 - ] + ], [1, 2, 3, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[1, 1, 3, 3], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) def testConv2D2x2Filter(self): - expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0] + expected_output = np.reshape( + [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0], [1, 1, 2, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], - stride=1, + strides=[1, 1], + padding="VALID", + expected=expected_output) + + def testConv2D2x2Filter2x1Dilation(self): + expected_output = np.array([[[[72], [82], [92]], [[112], [122], [132]]]]) + self._VerifyValues( + input_sizes=[1, 4, 4, 1], + filter_sizes=[2, 2, 1, 1], + strides=[1, 1], + dilations=[2, 1], padding="VALID", expected=expected_output) def testConv2D1x2Filter(self): - expected_output = [ + expected_output = np.reshape([ 231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0, 936.0, 1029.0 - ] + ], [1, 2, 2, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[1, 2, 3, 3], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) def testConv2D2x2FilterStride2(self): - expected_output = [2271.0, 2367.0, 2463.0] + expected_output = np.reshape([2271.0, 2367.0, 2463.0], [1, 1, 1, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], - stride=2, + strides=[2, 2], padding="VALID", expected=expected_output) def testConv2D2x2FilterStride2Same(self): - expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0] + expected_output = np.reshape( + [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0], [1, 1, 2, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], - stride=2, + strides=[2, 2], padding="SAME", expected=expected_output) + def testConv2DEmptyDilation(self): + self._VerifyValues( + input_sizes=[0, 2, 3, 3], + filter_sizes=[1, 1, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=np.zeros([0, 2, 3, 3])) + + def testConv2D2x2FilterDilation(self): + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[2, 2, 3, 3], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=np.reshape([2667, 2781, 2895], [1, 1, 1, 3])) + + def testConv2D1x2FilterDilation(self): + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[1, 2, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=np.array([[[[231, 252, 273], [384, 423, 462]], + [[690, 765, 840], [843, 936, 1029]]]])) + + def testConv2DKernelSizeMatchesInputSizeDilation(self): + self._VerifyValues( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 2, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + expected=np.reshape([108, 128], [1, 1, 1, 2])) + class Conv2DBackpropInputTest(XLATestCase): - def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride, - padding, expected): + def _VerifyValues(self, + input_sizes=None, + filter_sizes=None, + out_backprop_sizes=None, + strides=None, + dilations=None, + padding=None, + expected=None): """Tests that gen_nn_ops.conv2d_backprop_input produces the expected output. Args: @@ -125,7 +195,8 @@ class Conv2DBackpropInputTest(XLATestCase): filter_sizes: Filter tensor dimensions in [kernel_rows, kernel_cols, input_depth, output_depth]. out_backprop_sizes: Output gradients tensor dimensions. - stride: Stride. + strides: Strides. + dilations: Dilations. padding: Padding type. expected: Expected output. """ @@ -134,21 +205,25 @@ class Conv2DBackpropInputTest(XLATestCase): x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes) x2 = np.arange( 1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes) - strides = [1, stride, stride, 1] + strides = [1] + strides + [1] + if dilations is not None: + dilations = [1] + dilations + [1] with self.test_session() as sess: + t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) + t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): - t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) - t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) out = gen_nn_ops.conv2d_backprop_input( input_sizes=input_sizes, filter=t1, out_backprop=t2, strides=strides, + dilations=dilations, padding=padding, data_format="NHWC") value = sess.run(out, {t1: x1, t2: x2}) - self.assertArrayNear(expected, np.ravel(value), 1e-3) + self.assertAllEqual(input_sizes, value.shape) + self.assertAllClose(expected, np.ravel(value), 1e-3) def testConv2D1x1Filter(self): expected_output = [ @@ -160,7 +235,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 4, 4, 3], filter_sizes=[1, 1, 3, 2], out_backprop_sizes=[1, 4, 4, 2], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -170,7 +245,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 1, 5, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -180,7 +255,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 1, 6, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -190,7 +265,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 1, 7, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -200,7 +275,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 2, 3, 1], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -213,7 +288,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], out_backprop_sizes=[1, 1, 2, 3], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -226,7 +301,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], out_backprop_sizes=[1, 2, 3, 3], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -236,7 +311,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 3, 3, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 3, 2, 1], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -246,7 +321,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 3, 3, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 3, 3, 1], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -256,7 +331,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 3, 5, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 2, 2, 1], - stride=2, + strides=[2, 2], padding="VALID", expected=expected_output) @@ -266,15 +341,76 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=2, + strides=[2, 2], padding="SAME", expected=expected_output) + def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): + self._VerifyValues( + input_sizes=[1, 3, 6, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 5, 1], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=[1, 4, 7, 10, 13, 10, 0, 0, 0, 0, 0, 0, 3, 10, 17, 24, 31, 20]) + + def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 1, 1], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=[1, 0, 2, 3, 0, 4]) + + def testConv2DEmptyBackpropInputDilation1x2(self): + self._VerifyValues( + input_sizes=[0, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[0, 1, 1, 1], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=np.zeros([0])) + + def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): + # The GPU version of this test is not very stable. So adjusting the + # error threshold to 1e-4. + self._VerifyValues( + input_sizes=[1, 3, 2, 3], + filter_sizes=[2, 2, 3, 3], + out_backprop_sizes=[1, 1, 1, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=[ + 14, 32, 50, 68, 86, 104, 0, 0, 0, 0, 0, 0, 122, 140, 158, 176, 194, + 212 + ]) + + def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self): + self._VerifyValues( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 2, 1, 2], + out_backprop_sizes=[1, 1, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + expected=[5, 0, 11, 0, 0, 0, 17, 0, 23]) + class Conv2DBackpropFilterTest(XLATestCase): - def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride, - padding, expected): + def _VerifyValues(self, + input_sizes=None, + filter_sizes=None, + out_backprop_sizes=None, + strides=None, + dilations=None, + padding=None, + expected=None): """Tests that gen_nn_ops.conv2d_backprop_filter produces the right output. Args: @@ -283,7 +419,8 @@ class Conv2DBackpropFilterTest(XLATestCase): filter_sizes: Filter tensor dimensions in [kernel_rows, kernel_cols, input_depth, output_depth]. out_backprop_sizes: Output gradients tensor dimensions. - stride: Stride. + strides: Stride. + dilations: Dilations. padding: Padding type. expected: Expected output. """ @@ -293,22 +430,26 @@ class Conv2DBackpropFilterTest(XLATestCase): x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes) x2 = np.arange( 1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes) - strides = [1, stride, stride, 1] + strides = [1] + strides + [1] + if dilations is not None: + dilations = [1] + dilations + [1] with self.test_session() as sess: + t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) + t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): - t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) - t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) tensor = gen_nn_ops.conv2d_backprop_filter( input=t1, filter_sizes=filter_sizes, out_backprop=t2, strides=strides, + dilations=dilations, padding=padding, data_format="NHWC") value = sess.run(tensor, {t1: x1, t2: x2}) - self.assertArrayNear(expected, np.ravel(value), 1e-3) + self.assertAllEqual(filter_sizes, value.shape) + self.assertAllClose(expected, np.ravel(value), 1e-3) def testConv2D1x1Filter(self): expected_output = [8056, 8432, 8312, 8704, 8568, 8976] @@ -316,7 +457,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 4, 4, 3], filter_sizes=[1, 1, 3, 2], out_backprop_sizes=[1, 4, 4, 2], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -326,7 +467,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 3, 3, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 3, 2, 1], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -336,7 +477,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -350,7 +491,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], out_backprop_sizes=[1, 1, 2, 3], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -360,7 +501,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 5, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -370,7 +511,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 6, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -380,7 +521,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 7, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -390,7 +531,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 4, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -400,7 +541,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 4, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 1, 4, 1], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -410,7 +551,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 4, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=2, + strides=[2, 2], padding="SAME", expected=expected_output) @@ -420,7 +561,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 2, 3, 1], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -430,7 +571,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 3, 5, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 2, 2, 1], - stride=2, + strides=[2, 2], padding="VALID", expected=expected_output) @@ -440,10 +581,64 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=2, + strides=[2, 2], padding="SAME", expected=expected_output) + def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): + self._VerifyValues( + input_sizes=[1, 3, 6, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 5, 1], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=[55, 70, 235, 250]) + + def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 1, 1], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=[1, 3, 4, 6]) + + def testConv2DEmptyBackpropFilterDilation1x2(self): + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 0], + out_backprop_sizes=[1, 1, 1, 0], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=np.zeros([0])) + + def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): + self._VerifyValues( + input_sizes=[1, 3, 4, 3], + filter_sizes=[2, 2, 3, 3], + out_backprop_sizes=[1, 1, 2, 3], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + expected=[ + 17, 22, 27, 22, 29, 36, 27, 36, 45, 47, 64, 81, 52, 71, 90, 57, 78, + 99, 137, 190, 243, 142, 197, 252, 147, 204, 261, 167, 232, 297, 172, + 239, 306, 177, 246, 315 + ]) + + def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): + self._VerifyValues( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 2, 1, 2], + out_backprop_sizes=[1, 1, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + expected=[1, 2, 3, 6, 7, 14, 9, 18]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py new file mode 100644 index 0000000000000000000000000000000000000000..afb5fa4bb4fefe5bc2ecded826143ffc83c2b559 --- /dev/null +++ b/tensorflow/compiler/tests/fft_test.py @@ -0,0 +1,204 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for FFT via the XLA JIT.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np +import scipy.signal as sps + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.contrib.signal.python.ops import spectral_ops as signal +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import spectral_ops +from tensorflow.python.platform import googletest + +BATCH_DIMS = (3, 5) +RTOL = 0.02 # Eigen/cuFFT differ widely from np, especially for FFT3D +ATOL = 1e-3 + + +def pick_10(x): + x = list(x) + np.random.seed(123) + np.random.shuffle(x) + return x[:10] + + +def to_32bit(x): + if x.dtype == np.complex128: + return x.astype(np.complex64) + if x.dtype == np.float64: + return x.astype(np.float32) + return x + + +POWS_OF_2 = 2**np.arange(3, 12) +INNER_DIMS_1D = list((x,) for x in POWS_OF_2) +POWS_OF_2 = 2**np.arange(3, 8) # To avoid OOM on GPU. +INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2)) +INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2)) + + +class FFTTest(XLATestCase): + + def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, + tf_method): + for indims in inner_dims: + print("nfft =", indims) + shape = BATCH_DIMS + indims + data = np.arange(np.prod(shape) * 2) / np.prod(indims) + np.random.seed(123) + np.random.shuffle(data) + data = np.reshape(data.astype(np.float32).view(np.complex64), shape) + data = to_32bit(complex_to_input(data)) + expected = to_32bit(input_to_expected(data)) + with self.test_session() as sess: + with self.test_scope(): + ph = array_ops.placeholder( + dtypes.as_dtype(data.dtype), shape=data.shape) + out = tf_method(ph) + value = sess.run(out, {ph: data}) + self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) + + def testContribSignalSTFT(self): + ws = 512 + hs = 128 + dims = (ws * 20,) + shape = BATCH_DIMS + dims + data = np.arange(np.prod(shape)) / np.prod(dims) + np.random.seed(123) + np.random.shuffle(data) + data = np.reshape(data.astype(np.float32), shape) + window = sps.get_window("hann", ws) + expected = sps.stft( + data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2] + expected = np.swapaxes(expected, -1, -2) + expected *= window.sum() # scipy divides by window sum + with self.test_session() as sess: + with self.test_scope(): + ph = array_ops.placeholder( + dtypes.as_dtype(data.dtype), shape=data.shape) + out = signal.stft(ph, ws, hs) + + value = sess.run(out, {ph: data}) + self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) + + def testFFT(self): + self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft, + spectral_ops.fft) + + def testFFT2D(self): + self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2, + spectral_ops.fft2d) + + def testFFT3D(self): + self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, + lambda x: np.fft.fftn(x, axes=(-3, -2, -1)), + spectral_ops.fft3d) + + def testIFFT(self): + self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft, + spectral_ops.ifft) + + def testIFFT2D(self): + self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2, + spectral_ops.ifft2d) + + def testIFFT3D(self): + self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, + lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)), + spectral_ops.ifft3d) + + def testRFFT(self): + self._VerifyFftMethod( + INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]), + lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value])) + + def testRFFT2D(self): + + def _tf_fn(x): + return spectral_ops.rfft2d( + x, fft_length=[x.shape[-2].value, x.shape[-1].value]) + + self._VerifyFftMethod( + INNER_DIMS_2D, np.real, + lambda x: np.fft.rfft2(x, s=[x.shape[-2], x.shape[-1]]), _tf_fn) + + def testRFFT3D(self): + + def _to_expected(x): + return np.fft.rfftn( + x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]]) + + def _tf_fn(x): + return spectral_ops.rfft3d( + x, + fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value]) + + self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + + def testIRFFT(self): + + def _tf_fn(x): + return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) + + self._VerifyFftMethod( + INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]), + lambda x: np.fft.irfft(x, n=2 * (x.shape[-1] - 1)), _tf_fn) + + def testIRFFT2D(self): + + def _tf_fn(x): + return spectral_ops.irfft2d( + x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)]) + + self._VerifyFftMethod( + INNER_DIMS_2D, + lambda x: np.fft.rfft2(np.real(x), s=[x.shape[-2], x.shape[-1]]), + lambda x: np.fft.irfft2(x, s=[x.shape[-2], 2 * (x.shape[-1] - 1)]), + _tf_fn) + + def testIRFFT3D(self): + + def _to_input(x): + return np.fft.rfftn( + np.real(x), + axes=(-3, -2, -1), + s=[x.shape[-3], x.shape[-2], x.shape[-1]]) + + def _to_expected(x): + return np.fft.irfftn( + x, + axes=(-3, -2, -1), + s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)]) + + def _tf_fn(x): + return spectral_ops.irfft3d( + x, + fft_length=[ + x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1) + ]) + + self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 7e3871312c86530b6d3cb0bbacc16c25d3469832..f9db4cf2017c0b4b6dc0cfeeda6dca7bb9d14f19 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -161,9 +161,9 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose( + self.assertAllCloseAccordingToType( np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5) - self.assertAllClose( + self.assertAllCloseAccordingToType( np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) def testFtrlWithL1(self): @@ -189,10 +189,10 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-7.66718769, -10.91273689]), var0.eval(), - rtol=1e-4) - self.assertAllClose(np.array([-0.93460727, -1.86147261]), var1.eval(), - rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-7.66718769, -10.91273689]), var0.eval(), rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in self.float_types: @@ -217,10 +217,10 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-0.24059935, -0.46829352]), var0.eval(), - rtol=1e-5) - self.assertAllClose(np.array([-0.02406147, -0.04830509]), var1.eval(), - rtol=1e-5) + self.assertAllCloseAccordingToType( + np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5) + self.assertAllCloseAccordingToType( + np.array([-0.02406147, -0.04830509]), var1.eval(), rtol=1e-5) def testFtrlWithL1_L2_L2Shrinkage(self): """Test the new FTRL op with support for l2 shrinkage. @@ -244,18 +244,18 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) # Run 10 steps FTRL for _ in range(10): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-0.21931979, -0.40642974]), var0.eval(), - rtol=1e-4) - self.assertAllClose(np.array([-0.0282721, -0.07188385]), var1.eval(), - rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4) # When variables are initialized with Zero, FTRL-Proximal has two properties: # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical @@ -272,8 +272,8 @@ class FtrlOptimizerTest(XLATestCase): with self.test_session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) - self.assertAllClose(val0, val2, rtol=1e-4) - self.assertAllClose(val1, val3, rtol=1e-4) + self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4) + self.assertAllCloseAccordingToType(val1, val3, rtol=1e-4) def testEquivGradientDescentwithoutRegularization(self): steps = 5 @@ -284,8 +284,8 @@ class FtrlOptimizerTest(XLATestCase): val2, val3 = self.equivGradientDescentTest_GradientDescentPart( steps, dtype) - self.assertAllClose(val0, val2, rtol=1e-5) - self.assertAllClose(val1, val3, rtol=1e-5) + self.assertAllCloseAccordingToType(val0, val2, rtol=1e-5) + self.assertAllCloseAccordingToType(val1, val3, rtol=1e-5) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 00a9c9a65ba03d099581a3ee0dbe32c33e111231..a80d69fa5f5099b8a8b67df0da9c92b957e9d194 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -155,7 +155,7 @@ class FusedBatchNormTest(XLATestCase): def testLearningWithGradientChecker(self): self._testLearning(True) - def testGradient(self): + def testGradientTraining(self): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. channel = 3 @@ -175,7 +175,7 @@ class FusedBatchNormTest(XLATestCase): var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC") + grad, x, scale, mean, var, data_format="NHWC", is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { @@ -193,6 +193,53 @@ class FusedBatchNormTest(XLATestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + def testGradientInference(self): + # TODO(b/64270657): Use gradient_checker here in addition to comparing with + # this reference implementation. + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] + grad_val = np.random.random_sample(x_shape).astype(np.float32) + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + mean_val = np.random.random_sample(scale_shape).astype(np.float32) + var_val = np.random.random_sample(scale_shape).astype(np.float32) + + with self.test_session() as sess, self.test_scope(): + grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") + x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") + var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + with self.test_scope(): + out = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad_x, grad_scale, grad_offset, _, _ = out + + ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + + grad_x_val, grad_scale_val, grad_offset_val, = sess.run( + [grad_x, grad_scale, grad_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run( + [ref_x, ref_scale, ref_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + + self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2) + self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) + self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e84b790037c3b341a01c0a4d295e36890ea1f28e --- /dev/null +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -0,0 +1,547 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for image ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import colorsys +import math + +import numpy as np + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_image_ops +from tensorflow.python.ops import image_ops +from tensorflow.python.platform import test + + +class RGBToHSVTest(XLATestCase): + + def testBatch(self): + # Build an arbitrary RGB image + np.random.seed(7) + batch_size = 5 + shape = (batch_size, 2, 7, 3) + + for nptype in self.float_types: + inp = np.random.rand(*shape).astype(nptype) + + # Convert to HSV and back, as a batch and individually + with self.test_session() as sess: + batch0 = array_ops.placeholder(nptype, shape=shape) + with self.test_scope(): + batch1 = image_ops.rgb_to_hsv(batch0) + batch2 = image_ops.hsv_to_rgb(batch1) + split0 = array_ops.unstack(batch0) + with self.test_scope(): + split1 = list(map(image_ops.rgb_to_hsv, split0)) + split2 = list(map(image_ops.hsv_to_rgb, split1)) + join1 = array_ops.stack(split1) + join2 = array_ops.stack(split2) + batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2], + { + batch0: inp + }) + + # Verify that processing batch elements together is the same as separate + self.assertAllClose(batch1, join1) + self.assertAllClose(batch2, join2) + self.assertAllClose(batch2, inp) + + def testRGBToHSVRoundTrip(self): + data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + for nptype in self.float_types: + rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255. + with self.test_session(): + placeholder = array_ops.placeholder(nptype) + with self.test_scope(): + hsv = image_ops.rgb_to_hsv(placeholder) + rgb = image_ops.hsv_to_rgb(hsv) + rgb_tf = rgb.eval(feed_dict={placeholder: rgb_np}) + self.assertAllClose(rgb_tf, rgb_np) + + def testRGBToHSVNumpy(self): + """Tests the RGB to HSV conversion matches a reference implementation.""" + for nptype in self.float_types: + rgb_flat = np.random.random(64 * 3).reshape((64, 3)).astype(nptype) + rgb_np = rgb_flat.reshape(4, 4, 4, 3) + hsv_np = np.array([colorsys.rgb_to_hsv(r, g, b) for r, g, b in rgb_flat]) + hsv_np = hsv_np.reshape(4, 4, 4, 3) + with self.test_session(): + placeholder = array_ops.placeholder(nptype) + with self.test_scope(): + hsv_op = image_ops.rgb_to_hsv(placeholder) + hsv_tf = hsv_op.eval(feed_dict={placeholder: rgb_np}) + self.assertAllClose(hsv_tf, hsv_np) + + +class AdjustContrastTest(XLATestCase): + + def _testContrast(self, x_np, y_np, contrast_factor): + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_np.shape) + flt_x = image_ops.convert_image_dtype(x, dtypes.float32) + with self.test_scope(): + y = image_ops.adjust_contrast(flt_x, contrast_factor) + y = image_ops.convert_image_dtype(y, x.dtype, saturate=True) + y_tf = y.eval({x: x_np}) + self.assertAllClose(y_tf, y_np, 1e-6) + + def testFloatContrast(self): + x_shape = [1, 2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.float32).reshape(x_shape) / 255. + + y_data = [ + -45.25, -90.75, -92.5, 62.75, 169.25, 333.5, 28.75, -84.75, 349.5, + 134.75, 409.25, -116.5 + ] + y_np = np.array(y_data, dtype=np.float32).reshape(x_shape) / 255. + + self._testContrast(x_np, y_np, contrast_factor=2.0) + + def testBatchContrast(self): + x_shape = [2, 1, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + y_data = [0, 0, 0, 81, 200, 255, 10, 0, 255, 116, 255, 0] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + self._testContrast(x_np, y_np, contrast_factor=2.0) + + def _adjustContrastNp(self, x_np, contrast_factor): + mean = np.mean(x_np, (1, 2), keepdims=True) + y_np = mean + contrast_factor * (x_np - mean) + return y_np + + def _adjustContrastTf(self, x_np, contrast_factor): + with self.test_session(): + x = array_ops.placeholder(np.float32) + with self.test_scope(): + y = image_ops.adjust_contrast(x, contrast_factor) + y_tf = y.eval({x: x_np}) + return y_tf + + def testRandomContrast(self): + x_shapes = [ + [1, 2, 2, 3], + [2, 1, 2, 3], + [1, 2, 2, 3], + [2, 5, 5, 3], + [2, 1, 1, 3], + ] + for x_shape in x_shapes: + x_np = np.random.rand(*x_shape) * 255. + contrast_factor = np.random.rand() * 2.0 + 0.1 + y_np = self._adjustContrastNp(x_np, contrast_factor) + y_tf = self._adjustContrastTf(x_np, contrast_factor) + self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5) + + +class AdjustHueTest(XLATestCase): + + def testAdjustNegativeHue(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + delta = -0.25 + y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + flt_x = image_ops.convert_image_dtype(x, dtypes.float32) + with self.test_scope(): + y = gen_image_ops.adjust_hue(flt_x, delta) + y = image_ops.convert_image_dtype(y, x.dtype, saturate=True) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def testAdjustPositiveHue(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + delta = 0.25 + y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + flt_x = image_ops.convert_image_dtype(x, dtypes.float32) + with self.test_scope(): + y = gen_image_ops.adjust_hue(flt_x, delta) + y = image_ops.convert_image_dtype(y, x.dtype, saturate=True) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def testBatchAdjustHue(self): + x_shape = [2, 1, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + delta = 0.25 + y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + flt_x = image_ops.convert_image_dtype(x, dtypes.float32) + with self.test_scope(): + y = gen_image_ops.adjust_hue(flt_x, delta) + y = image_ops.convert_image_dtype(y, x.dtype, saturate=True) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def _adjustHueNp(self, x_np, delta_h): + self.assertEqual(x_np.shape[-1], 3) + x_v = x_np.reshape([-1, 3]) + y_v = np.ndarray(x_v.shape, dtype=x_v.dtype) + channel_count = x_v.shape[0] + for i in xrange(channel_count): + r = x_v[i][0] + g = x_v[i][1] + b = x_v[i][2] + h, s, v = colorsys.rgb_to_hsv(r, g, b) + h += delta_h + h = math.fmod(h + 10.0, 1.0) + r, g, b = colorsys.hsv_to_rgb(h, s, v) + y_v[i][0] = r + y_v[i][1] = g + y_v[i][2] = b + return y_v.reshape(x_np.shape) + + def _adjustHueTf(self, x_np, delta_h): + with self.test_session(): + x = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + y = gen_image_ops.adjust_hue(x, delta_h) + y_tf = y.eval({x: x_np}) + return y_tf + + def testAdjustRandomHue(self): + x_shapes = [ + [2, 2, 3], + [4, 2, 3], + [2, 4, 3], + [2, 5, 3], + [1000, 1, 3], + ] + test_styles = [ + "all_random", + "rg_same", + "rb_same", + "gb_same", + "rgb_same", + ] + for x_shape in x_shapes: + for test_style in test_styles: + x_np = np.random.rand(*x_shape) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + if test_style == "all_random": + pass + elif test_style == "rg_same": + x_np[..., 1] = x_np[..., 0] + elif test_style == "rb_same": + x_np[..., 2] = x_np[..., 0] + elif test_style == "gb_same": + x_np[..., 2] = x_np[..., 1] + elif test_style == "rgb_same": + x_np[..., 1] = x_np[..., 0] + x_np[..., 2] = x_np[..., 0] + else: + raise AssertionError("Invalid test style: %s" % (test_style)) + y_np = self._adjustHueNp(x_np, delta_h) + y_tf = self._adjustHueTf(x_np, delta_h) + self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-4) + + def testInvalidShapes(self): + fused = False + if not fused: + # The tests are known to pass with the fused adjust_hue. We will enable + # them when the fused implementation is the default. + return + x_np = np.random.rand(2, 3) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + fused = False + with self.assertRaisesRegexp(ValueError, "Shape must be at least rank 3"): + self._adjustHueTf(x_np, delta_h) + x_np = np.random.rand(4, 2, 4) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + with self.assertRaisesOpError("input must have 3 channels"): + self._adjustHueTf(x_np, delta_h) + + +class AdjustSaturationTest(XLATestCase): + + def _adjust_saturation(self, image, saturation_factor): + image = ops.convert_to_tensor(image, name="image") + orig_dtype = image.dtype + flt_image = image_ops.convert_image_dtype(image, dtypes.float32) + with self.test_scope(): + saturation_adjusted_image = gen_image_ops.adjust_saturation( + flt_image, saturation_factor) + return image_ops.convert_image_dtype(saturation_adjusted_image, orig_dtype) + + def testHalfSaturation(self): + x_shape = [2, 2, 3] + x_rgb_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_rgb_data, dtype=np.uint8).reshape(x_shape) + + saturation_factor = 0.5 + y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128] + y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + y = self._adjust_saturation(x, saturation_factor) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def testTwiceSaturation(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + saturation_factor = 2.0 + y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + y = self._adjust_saturation(x, saturation_factor) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def _adjustSaturationNp(self, x_np, scale): + self.assertEqual(x_np.shape[-1], 3) + x_v = x_np.reshape([-1, 3]) + y_v = np.ndarray(x_v.shape, dtype=x_v.dtype) + channel_count = x_v.shape[0] + for i in xrange(channel_count): + r = x_v[i][0] + g = x_v[i][1] + b = x_v[i][2] + h, s, v = colorsys.rgb_to_hsv(r, g, b) + s *= scale + s = min(1.0, max(0.0, s)) + r, g, b = colorsys.hsv_to_rgb(h, s, v) + y_v[i][0] = r + y_v[i][1] = g + y_v[i][2] = b + return y_v.reshape(x_np.shape) + + def testAdjustRandomSaturation(self): + x_shapes = [ + [2, 2, 3], + [4, 2, 3], + [2, 4, 3], + [2, 5, 3], + [1000, 1, 3], + ] + test_styles = [ + "all_random", + "rg_same", + "rb_same", + "gb_same", + "rgb_same", + ] + with self.test_session(): + for x_shape in x_shapes: + for test_style in test_styles: + x_np = np.random.rand(*x_shape) * 255. + scale = np.random.rand() + if test_style == "all_random": + pass + elif test_style == "rg_same": + x_np[..., 1] = x_np[..., 0] + elif test_style == "rb_same": + x_np[..., 2] = x_np[..., 0] + elif test_style == "gb_same": + x_np[..., 2] = x_np[..., 1] + elif test_style == "rgb_same": + x_np[..., 1] = x_np[..., 0] + x_np[..., 2] = x_np[..., 0] + else: + raise AssertionError("Invalid test style: %s" % (test_style)) + y_baseline = self._adjustSaturationNp(x_np, scale) + x = array_ops.placeholder(dtypes.float32, shape=x_shape) + with self.test_scope(): + y_fused = self._adjust_saturation(x, + scale).eval(feed_dict={ + x: x_np + }) + self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) + + +class ResizeBilinearTest(XLATestCase): + + def _assertForwardOpMatchesExpected(self, + image_np, + target_shape, + expected=None): + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + image = array_ops.placeholder(image_np.dtype) + resized = gen_image_ops.resize_bilinear( + image, target_shape, align_corners=True) + out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def _assertBackwardOpMatchesExpected(self, + grads_np, + input_shape=None, + dtype=None, + expected=None): + if input_shape is None: + self.fail("input_shape must be specified") + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + dtype = dtype or np.float32 + grads = array_ops.placeholder(np.float32) + resized = gen_image_ops._resize_bilinear_grad( + grads, + np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype), + align_corners=True) + out = sess.run(resized, {grads: grads_np[np.newaxis, :, :, np.newaxis]}) + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def testAlignCorners1x2To3x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) + + def testAlignCorners1x2To3x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), + input_shape=[1, 2], + dtype=dtype, + expected=np.array([[9, 12]], dtype=np.float32)) + + def testAlignCorners2x2To1x1(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [1, 1], + expected=np.array([[1]], dtype=np.float32)) + + def testAlignCorners2x2To1x1Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7]], dtype=np.float32), + input_shape=[2, 2], + dtype=dtype, + expected=np.array([[7, 0], [0, 0]], dtype=np.float32)) + + def testAlignCorners2x2To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], dtype=np.float32)) + + def testAlignCorners2x2To3x3Grad(self): + self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[2, 2], + expected=np.array([[5.25, 8.25], [14.25, 17.25]], dtype=np.float32)) + + def testAlignCorners3x3To2x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [2, 2], + expected=np.array([[1, 3], [7, 9]], dtype=np.float32)) + + def testAlignCorners3x3To2x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7, 13], [22, 4]], dtype=np.float32), + input_shape=[3, 3], + dtype=dtype, + expected=np.array( + [[7, 0, 13], [0, 0, 0], [22, 0, 4]], dtype=np.float32)) + + def testAlignCorners4x4To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=dtype), [3, 3], + expected=np.array( + [[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], dtype=np.float32)) + + def testAlignCorners4x4To3x3Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[4, 4], + dtype=dtype, + expected=np.array( + [[1, 1, 1, 3], [2, 1.25, 1.25, 3], [2, 1.25, 1.25, 3], + [7, 4, 4, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To9x9(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [9, 9], + expected=np.array( + [[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [ + 1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75 + ], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [ + 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25 + ], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [ + 4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75 + ], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [ + 6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25 + ], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], + dtype=np.float32)) + + def testAlignCorners3x3To9x9Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array( + [[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [ + 1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75 + ], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [ + 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25 + ], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [ + 4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75 + ], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [ + 6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25 + ], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], + dtype=np.float32), + input_shape=[3, 3], + dtype=dtype, + expected=np.array( + [[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]], + dtype=np.float32)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index c00e3035a0982b2b2e59eb6f53499918515ae71d..af9394e7d7dc9cf7dd009420ff9c845aec8785bd 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -96,28 +96,27 @@ class MomentumOptimizerTest(XLATestCase): def testNesterovMomentum(self): for dtype in self.float_types: with self.test_session(), self.test_scope(): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - var0_np = np.array([1.0, 2.0], dtype=dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype) + var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype) + var0_np = np.array([0.1, 0.2], dtype=dtype) + var1_np = np.array([0.3, 0.4], dtype=dtype) accum0_np = np.array([0.0, 0.0], dtype=dtype) accum1_np = np.array([0.0, 0.0], dtype=dtype) - cost = 5 * var0 * var0 + 3 * var1 + cost = 0.4 * var0 * var0 + 0.9 * var1 global_step = resource_variable_ops.ResourceVariable( array_ops.zeros([], dtypes.int32), name="global_step") mom_op = momentum_lib.MomentumOptimizer( - learning_rate=2.0, momentum=0.9, use_nesterov=True) + learning_rate=0.1, momentum=0.9, use_nesterov=True) opt_op = mom_op.minimize(cost, global_step, [var0, var1]) variables.global_variables_initializer().run() for _ in range(1, 5): opt_op.run() var0_np, accum0_np = self._update_nesterov_momentum_numpy( - var0_np, accum0_np, var0_np * 10, 2.0, 0.9) - var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np, - accum1_np, - 3, 2.0, 0.9) - self.assertAllClose(var0_np, var0.eval()) - self.assertAllClose(var1_np, var1.eval()) + var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9) + var1_np, accum1_np = self._update_nesterov_momentum_numpy( + var1_np, accum1_np, 0.9, 0.1, 0.9) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 6a8c3bcd55a6e454a19b6249cf4eb48739c8657f..e72dd4eea9f127e1df96ab166103c4c16372adb6 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -93,11 +93,11 @@ class OpTestBuilder { public: explicit OpTestBuilder(const string& op_name); - // Adds an input 'tensor'. + // Adds an input 'tensor' as a Placeholder node. OpTestBuilder& Input(const Tensor& tensor); - // Adds a random input tensor with 'type'. If 'dims' is not provided, - // RandomDims() is used. + // Adds a random input tensor with 'type' as a Placeholder node. + // If 'dims' is not provided, RandomDims() is used. OpTestBuilder& RandomInput(DataType type); OpTestBuilder& RandomInput(DataType type, std::vector dims); @@ -998,6 +998,13 @@ TEST_F(OpTest, Atanh) { }); } +TEST_F(OpTest, Atan) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Atan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Atan2) { Repeatedly([this]() { auto dims = BroadcastableDims(); @@ -1368,6 +1375,121 @@ TEST_F(OpTest, Conj) { }); } +TEST_F(OpTest, FFT) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FFT").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, FFT2D) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FFT2D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, FFT3D) { + Repeatedly([this]() { + std::vector dims = RandomDims(3, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FFT3D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, IFFT) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("IFFT").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, IFFT2D) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("IFFT2D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, IFFT3D) { + Repeatedly([this]() { + std::vector dims = RandomDims(3, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("IFFT3D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, RFFT) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, kDefaultMaxRank, 3); + Tensor fft_shape = test::AsTensor(AsInt32s({dims[dims.size() - 1]})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("RFFT").RandomInput(DT_FLOAT, dims).Input(fft_shape)); + }); +} + +TEST_F(OpTest, RFFT2D) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, kDefaultMaxRank, 3); + Tensor fft_shape = test::AsTensor( + AsInt32s({dims[dims.size() - 2], dims[dims.size() - 1]})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("RFFT2D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); + }); +} + +TEST_F(OpTest, RFFT3D) { + Repeatedly([this]() { + std::vector dims = RandomDims(3, kDefaultMaxRank, 3); + Tensor fft_shape = test::AsTensor(AsInt32s( + {dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("RFFT3D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); + }); +} + +TEST_F(OpTest, IRFFT) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, kDefaultMaxRank, 3); + int64 orig_size = dims[dims.size() - 1]; + dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT") + .RandomInput(DT_COMPLEX64, dims) + .Input(fft_shape)); + }); +} + +TEST_F(OpTest, IRFFT2D) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, kDefaultMaxRank, 3); + std::vector orig_size = {dims[dims.size() - 2], + dims[dims.size() - 1]}; + dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT2D") + .RandomInput(DT_COMPLEX64, dims) + .Input(fft_shape)); + }); +} + +TEST_F(OpTest, IRFFT3D) { + Repeatedly([this]() { + std::vector dims = RandomDims(3, kDefaultMaxRank, 3); + std::vector orig_size = { + dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]}; + dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT3D") + .RandomInput(DT_COMPLEX64, dims) + .Input(fft_shape)); + }); +} + TEST_F(OpTest, Conv2D) { Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); @@ -1382,7 +1504,7 @@ TEST_F(OpTest, Conv2D) { std::vector kernel_dims = {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}; - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2D") .RandomInput(type, data_dims) @@ -1407,7 +1529,7 @@ TEST_F(OpTest, Conv2DBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out})); - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropFilter") .RandomInput(type, activations) @@ -1433,7 +1555,7 @@ TEST_F(OpTest, Conv2DBackpropInput) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}; - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropInput") .Input(in_shape) @@ -1457,7 +1579,7 @@ TEST_F(OpTest, Conv3D) { std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out}; - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3D") .RandomInput(type, data) @@ -1482,7 +1604,7 @@ TEST_F(OpTest, Conv3DBackpropFilter) { Tensor kernel_shape = test::AsTensor( AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out})); - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3DBackpropFilterV2") .RandomInput(type, activations) @@ -2460,6 +2582,36 @@ TEST_F(OpTest, Reshape) { }); } +TEST_F(OpTest, ResizeBilinear) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinear") + .RandomInput(DT_FLOAT, in_dims) + .Input(test::AsTensor( + std::vector(out_dims.begin(), out_dims.end()))) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + +TEST_F(OpTest, ResizeBilinearGrad) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinearGrad") + .RandomInput(DT_FLOAT, in_dims) + .RandomInput(DT_FLOAT, + {in_dims[0], out_dims[0], out_dims[1], in_dims[3]}) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + TEST_F(OpTest, Reverse) { Repeatedly([this]() { std::vector dims = RandomDims(1); diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3260e63b23226d736a7ddc0f21a94a8c791e0442 --- /dev/null +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -0,0 +1,229 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for scan ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def numpy_reverse(x, axis): + length = len(x.shape) + if axis < 0: + axis = length + axis + + ix = [ + slice(None, None, -1) if i == axis else slice(None) for i in range(length) + ] + return x[ix] + + +def handle_options(func, x, axis, exclusive, reverse): + """Adds tf options to numpy scan ops.""" + length = len(x.shape) + if axis < 0: + axis = length + axis + + if reverse: + x = numpy_reverse(x, axis) + + if exclusive: + ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)] + ix_init = [ + slice(0, -1) if i == axis else slice(None) for i in range(length) + ] + if func == np.cumsum: + init = np.zeros_like(x[ix_head]) + elif func == np.cumprod: + init = np.ones_like(x[ix_head]) + else: + raise ValueError("Unknown scan function.") + x = np.concatenate([init, func(x[ix_init], axis)], axis=axis) + else: + x = func(x, axis=axis) + + if reverse: + x = numpy_reverse(x, axis) + return x + + +class CumsumTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( + feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumsum(p, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 10).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumsum(input_tensor, [0]).eval() + + +class CumprodTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + prod = math_ops.cumprod(p, axis, exclusive, reverse) + tf_out = prod.eval(feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumprod(x, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 11).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumprod(input_tensor, [0]).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index ac039e01623b954e291760fb9b50ef8eae3da7c1..a62925a1818da00cb0a9e82e1281db20fb38b208 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -330,8 +330,7 @@ class TensorArrayTest(xla_test.XLATestCase): # Find two different floating point types, create an array of # the first type, but try to read the other type. if len(self.float_types) > 1: - dtype1 = self.float_types[0] - dtype2 = self.float_types[1] + dtype1, dtype2 = list(self.float_types)[:2] with self.test_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtype1, tensor_array_name="foo", size=3) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index a9a3f4f97f649260e9863fff8ff05d046bd91947..0a6fe04d3cdd29f1d40d33be1f4319090e7ba3d1 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -33,6 +33,17 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest +def nhwc_to_format(x, data_format): + """Converts a numpy array from NHWC format to `data_format`.""" + rank = len(x.shape) + if data_format == "NCHW": + return np.transpose(x, [0, rank - 1] + list(range(1, rank - 1))) + elif data_format == "NHWC": + return x + else: + raise ValueError("Unknown format {}".format(data_format)) + + class UnaryOpsTest(XLATestCase): """Test cases for unary operators.""" @@ -56,7 +67,7 @@ class UnaryOpsTest(XLATestCase): output = op(pinp) result = session.run(output, {pinp: inp}) if equality_test is None: - equality_test = self.assertAllClose + equality_test = self.assertAllCloseAccordingToType equality_test(result, expected, rtol=rtol, atol=atol) def ListsAreClose(self, result, expected, rtol, atol): @@ -76,6 +87,12 @@ class UnaryOpsTest(XLATestCase): array_ops.diag_part, np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.diag, np.array([[1, 2], [3, 4]], dtype=dtype), + np.array( + [[[[1, 0], [0, 0]], [[0, 2], [0, 0]]], [[[0, 0], [3, 0]], + [[0, 0], [0, 4]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.identity, @@ -86,6 +103,21 @@ class UnaryOpsTest(XLATestCase): array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), + np.array( + [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, + np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), + np.array( + [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], + [[4, 0, 0], [0, 5, 0], [0, 0, 6]]], + [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], + [[10, 0, 0], [0, 11, 0], [0, 0, 12]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag_part, np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype), @@ -331,26 +363,23 @@ class UnaryOpsTest(XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - # TODO(b/65408531): Wider support for log (needs atan2). - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._assertOpOutputMatchesExpected( - math_ops.acosh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arccosh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.acosh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arccosh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) - self._assertOpOutputMatchesExpected( - math_ops.asinh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arcsinh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.asinh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arcsinh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) - self._assertOpOutputMatchesExpected( - math_ops.atanh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arctanh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.atanh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arctanh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.cosh, @@ -377,11 +406,10 @@ class UnaryOpsTest(XLATestCase): np.array([[1, 2j, 2 + 3j]], dtype=dtype), expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype)) - if atan2_supported: - self._assertOpOutputMatchesExpected( - math_ops.log, - np.array([[5j, 3 - 2j]], dtype=dtype), - expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.log, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.sin, @@ -395,27 +423,26 @@ class UnaryOpsTest(XLATestCase): # TODO(b/34703906): improve log1p implementation and make tolerance # tighter. - if atan2_supported: # TODO(b/34703906): log support - self._assertOpOutputMatchesExpected( - math_ops.log1p, - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), - expected=np.log1p( - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.log1p, + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), + expected=np.log1p( + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) - val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) - self._assertOpOutputMatchesExpected( - math_ops.rsqrt, val, expected=1 / np.sqrt(val)) + val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.rsqrt, val, expected=1 / np.sqrt(val)) - self._assertOpOutputMatchesExpected( - math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) + self._assertOpOutputMatchesExpected( + math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) - self._assertOpOutputMatchesExpected( - math_ops.sqrt, val, expected=np.sqrt(val)) + self._assertOpOutputMatchesExpected( + math_ops.sqrt, val, expected=np.sqrt(val)) - self._assertOpOutputMatchesExpected( - math_ops.tanh, - np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), - expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.tan, @@ -448,12 +475,10 @@ class UnaryOpsTest(XLATestCase): np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype), expected=np.array([[1, 1], [1, 1]], dtype=dtype)) - if atan2_supported: # TODO(b/34703906): atan2 support - self._assertOpOutputMatchesExpected( - math_ops.angle, - np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), - expected=np.angle( - np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.angle, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.angle(np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.conj, @@ -541,7 +566,8 @@ class UnaryOpsTest(XLATestCase): def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] - types = [dtypes.bool, dtypes.int32, dtypes.float32] + self.complex_tf_types + types = (set([dtypes.bool, dtypes.int32, dtypes.float32]) | + self.complex_tf_types) for shape in shapes: for src_type in types: for dst_type in types: @@ -641,55 +667,88 @@ class UnaryOpsTest(XLATestCase): equality_test=self.ListsAreClose) def testDepthToSpace(self): + def make_op(data_format): + def op(x): + return array_ops.depth_to_space(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - expected=np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - expected=np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format)) def testSpaceToDepth(self): + def make_op(data_format): + def op(x): + return array_ops.space_to_depth(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format)) def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index c50342dee45eba6ae54f01653ecc81ef096b547b..b08d6ab21e0746558cb3d4818d4c822c45d2e9ee 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -107,11 +107,26 @@ class VariableOpsTest(XLATestCase): [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], ).astype(dtype), sess.run(x)) + def testShape(self): + for dtype in self.numeric_types: + init = np.ones([2, 3]).astype(dtype) + with self.test_session() as session, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + session.run(variables.variables_initializer([v])) + h = v.handle + s32, s64 = session.run([ + resource_variable_ops.variable_shape(h), + resource_variable_ops.variable_shape(h, out_type=dtypes.int64) + ]) + self.assertEqual(s32.dtype, np.int32) + self.assertEqual(s64.dtype, np.int64) + self.assertAllEqual(s32, [2, 3]) + self.assertAllEqual(s64, [2, 3]) + def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" for dtype in self.numeric_types: with self.test_session() as session: - print(ops.get_default_graph()) with self.test_scope(): with variable_scope.variable_scope("ascope", use_resource=True): x = variable_scope.get_variable( diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 0be127997e5211f810ca791187486760881fe172..7e1f5c76ed65946363cc3c113ab1a9862f87b289 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -53,41 +53,100 @@ class XLATestCase(test.TestCase): super(XLATestCase, self).__init__(method_name) self.device = FLAGS.test_device self.has_custom_call = (self.device == 'XLA_CPU') - self.all_tf_types = [ + self._all_tf_types = set([ dtypes.as_dtype(types_pb2.DataType.Value(name)) for name in FLAGS.types.split(',') - ] - self.int_tf_types = [ - dtype for dtype in self.all_tf_types if dtype.is_integer - ] - self.float_tf_types = [ - dtype for dtype in self.all_tf_types if dtype.is_floating - ] - self.complex_tf_types = [ - dtype for dtype in self.all_tf_types if dtype.is_complex - ] - self.numeric_tf_types = ( - self.int_tf_types + self.float_tf_types + self.complex_tf_types) - - self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types] - self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types] - self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types] - self.complex_types = [ + ]) + self.int_tf_types = set([ + dtype for dtype in self._all_tf_types if dtype.is_integer + ]) + self._float_tf_types = set([ + dtype for dtype in self._all_tf_types if dtype.is_floating + ]) + self.complex_tf_types = set([ + dtype for dtype in self._all_tf_types if dtype.is_complex + ]) + self._numeric_tf_types = set( + self.int_tf_types | self._float_tf_types | self.complex_tf_types) + + self._all_types = set( + [dtype.as_numpy_dtype for dtype in self._all_tf_types]) + self.int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self._float_types = set( + [dtype.as_numpy_dtype for dtype in self._float_tf_types]) + self.complex_types = set([ dtype.as_numpy_dtype for dtype in self.complex_tf_types - ] - self.numeric_types = self.int_types + self.float_types + self.complex_types + ]) + self._numeric_types = set( + self.int_types | self._float_types | self.complex_types) # Parse the manifest file, if any, into a regex identifying tests to # disable self.disabled_regex = None + self._method_types_filter = dict() + # TODO(xpan): Make it text proto if it doesn't scale. + # Each line of the manifest file specifies an entry. The entry can be + # 1) TestNameRegex // E.g. CumprodTest.* Or + # 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16 + # The 1) disables the entire test. While 2) only filter some numeric types + # so that they are not used in those tests. + if FLAGS.disabled_manifest is not None: comments_re = re.compile('#.*$') manifest_file = open(FLAGS.disabled_manifest, 'r') - lines = manifest_file.read().splitlines() - lines = [comments_re.sub('', l).strip() for l in lines] - self.disabled_regex = re.compile('|'.join(lines)) + disabled_tests = [] + disabled_method_types = [] + for l in manifest_file.read().splitlines(): + entry = comments_re.sub('', l).strip().split(' ') + if len(entry) == 1: + disabled_tests.append(entry[0]) + elif len(entry) == 2: + disabled_method_types.append( + (entry[0], entry[1].strip().split(','))) + else: + raise ValueError('Bad entry in manifest file.') + + self.disabled_regex = re.compile('|'.join(disabled_tests)) + for method, types in disabled_method_types: + self._method_types_filter[method] = set([ + dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype + for name in types]) manifest_file.close() + @property + def all_tf_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + tf_types = set([dtypes.as_dtype(t) + for t in self._method_types_filter.get(name, set())]) + return self._all_tf_types - tf_types + + @property + def float_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._float_types - self._method_types_filter.get(name, set()) + + @property + def float_tf_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._float_tf_types - self._method_types_filter.get(name, set()) + + @property + def numeric_tf_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + tf_types = set([dtypes.as_dtype(t) + for t in self._method_types_filter.get(name, set())]) + return self._numeric_tf_types - tf_types + + @property + def numeric_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._numeric_types - self._method_types_filter.get(name, set()) + + @property + def all_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._all_types - self._method_types_filter.get(name, set()) + def setUp(self): super(XLATestCase, self).setUp() name = '{}.{}'.format(type(self).__name__, self._testMethodName) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5a81438b1c48e7f0ef66dae072092974db24c621..3c7dfef03dfb5d86dd63fd4aa84ad56081833035 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") package_group( name = "internal", @@ -25,6 +25,30 @@ package( load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +cc_library( + name = "tf2xla_supported_ops_lib", + srcs = ["tf2xla_supported_ops.cc"], + hdrs = ["tf2xla_supported_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_binary( + name = "tf2xla_supported_ops", + srcs = ["tf2xla_supported_ops_main.cc"], + visibility = ["//visibility:public"], + deps = [":tf2xla_supported_ops_lib"], +) + xla_proto_library( name = "tf2xla_proto", srcs = ["tf2xla.proto"], @@ -67,7 +91,6 @@ cc_library( # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], @@ -97,18 +120,21 @@ cc_library( cc_library( name = "xla_compiler", srcs = [ + "const_analysis.cc", + "graph_compiler.cc", "xla_compilation_device.cc", "xla_compiler.cc", "xla_context.cc", "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", - "graph_compiler.cc", + "xla_resource.cc", "xla_cpu_backend.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", ]), hdrs = [ + "const_analysis.h", "graph_compiler.h", "xla_compilation_device.h", "xla_compiler.h", @@ -116,11 +142,11 @@ cc_library( "xla_helpers.h", "xla_op_kernel.h", "xla_op_registry.h", + "xla_resource.h", ], visibility = [":friends"], deps = [ ":common", - ":const_analysis", ":dump_graph", ":functionalize_control_flow", ":sharding_util", @@ -180,6 +206,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -215,6 +242,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -328,28 +356,16 @@ tf_cc_test( ], ) -cc_library( - name = "const_analysis", - srcs = ["const_analysis.cc"], - hdrs = ["const_analysis.h"], - deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - tf_cc_test( name = "const_analysis_test", size = "small", srcs = ["const_analysis_test.cc"], deps = [ - ":const_analysis", + ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:ops", "//tensorflow/core:test", @@ -357,13 +373,6 @@ tf_cc_test( ], ) -cc_library( - name = "xla_local_runtime_context", - hdrs = ["xla_local_runtime_context.h"], - visibility = ["//visibility:public"], - deps = ["//tensorflow/core:framework_lite"], -) - cc_library( name = "dump_graph", srcs = [ @@ -400,6 +409,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index d57273d84442c17565a6ace1c29170a0f3ba583b..0249500910c6ae441f038fe9ad6178794f1997ac 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" @@ -27,93 +28,18 @@ namespace tensorflow { // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, std::vector* compile_time_const_args) { - // TODO(phawkins): annotate these on the kernel registrations, rather than - // using a hard-coded list. - // (operator, argument) pairs that must be compile-time constants. - const std::unordered_multimap compile_time_const_inputs = { - {"All", "reduction_indices"}, - {"Any", "reduction_indices"}, - {"ArgMin", "dimension"}, - {"ArgMax", "dimension"}, - {"AvgPoolGrad", "orig_input_shape"}, - {"AvgPool3DGrad", "orig_input_shape"}, - {"BatchToSpace", "crops"}, - {"BatchToSpaceND", "block_shape"}, - {"BatchToSpaceND", "crops"}, - {"BroadcastArgs", "s0"}, - {"BroadcastArgs", "s1"}, - {"BroadcastGradientArgs", "s0"}, - {"BroadcastGradientArgs", "s1"}, - {"Concat", "concat_dim"}, - {"ConcatV2", "axis"}, - {"ConcatOffset", "concat_dim"}, - {"ConcatOffset", "shape"}, - {"Conv2DBackpropFilter", "filter_sizes"}, - {"Conv2DBackpropInput", "input_sizes"}, - {"Conv3DBackpropFilterV2", "filter_sizes"}, - {"Conv3DBackpropInputV2", "input_sizes"}, - {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"}, - {"DepthwiseConv2dNativeBackpropInput", "input_sizes"}, - {"DynamicStitch", "indices"}, - {"ExpandDims", "dim"}, - {"Fill", "dims"}, - {"GatherV2", "axis"}, - {"InvertPermutation", "x"}, - {"LinSpace", "start"}, - {"LinSpace", "stop"}, - {"LinSpace", "num"}, - {"Max", "reduction_indices"}, - {"Mean", "reduction_indices"}, - {"Min", "reduction_indices"}, - {"OneHot", "depth"}, - {"Pad", "paddings"}, - {"PadV2", "paddings"}, - {"MirrorPad", "paddings"}, - {"Multinomial", "num_samples"}, - {"Prod", "reduction_indices"}, - {"RandomStandardNormal", "shape"}, - {"RandomUniform", "shape"}, - {"RandomUniformInt", "shape"}, - {"Range", "start"}, - {"Range", "limit"}, - {"Range", "delta"}, - {"Reshape", "shape"}, - {"ResourceStridedSliceAssign", "begin"}, - {"ResourceStridedSliceAssign", "end"}, - {"ResourceStridedSliceAssign", "strides"}, - {"Reverse", "dims"}, - {"ReverseV2", "axis"}, - {"Slice", "begin"}, - {"Slice", "size"}, - {"SpaceToBatch", "paddings"}, - {"SpaceToBatchND", "block_shape"}, - {"SpaceToBatchND", "paddings"}, - {"Split", "split_dim"}, - {"SplitV", "split_dim"}, - {"SplitV", "size_splits"}, - {"StackV2", "max_size"}, - {"StridedSlice", "begin"}, - {"StridedSlice", "end"}, - {"StridedSlice", "strides"}, - {"StridedSliceGrad", "shape"}, - {"StridedSliceGrad", "begin"}, - {"StridedSliceGrad", "end"}, - {"StridedSliceGrad", "strides"}, - {"Sum", "reduction_indices"}, - {"TensorArrayV3", "size"}, - {"TensorArraySplitV3", "lengths"}, - {"Tile", "multiples"}, - {"Transpose", "perm"}}; - // Operators that don't look at the data of their inputs, just the shapes. const std::unordered_set metadata_ops = { - "Rank", "Shape", "ShapeN", "Size", + "Rank", + "Shape", + "ShapeN", + "Size", }; Status status; std::unordered_set must_be_const; - auto visit = [&status, &metadata_ops, &compile_time_const_inputs, - &must_be_const, compile_time_const_args](Node* node) { + auto visit = [&status, &metadata_ops, &must_be_const, + compile_time_const_args](Node* node) { if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. @@ -136,16 +62,17 @@ Status BackwardsConstAnalysis(const Graph& g, } // Mark any compile-time constant operator arguments as const. - auto range = compile_time_const_inputs.equal_range(node->type_string()); - if (range.first == range.second) return; + const std::unordered_set* const_inputs = + XlaOpRegistry::CompileTimeConstantInputs(node->type_string()); + if (!const_inputs) return; NameRangeMap input_name_ranges; status = NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr); if (!status.ok()) return; - for (auto it = range.first; it != range.second; ++it) { - auto name_range = input_name_ranges.find(it->second); + for (const string& input : *const_inputs) { + auto name_range = input_name_ranges.find(input); if (name_range == input_name_ranges.end()) continue; for (Edge const* edge : node->in_edges()) { diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index ddd912b87315f7943915153b5bf73531107af54d..03603ee9baefd1d20d220faf63c9c1c427ebdf31 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -63,7 +63,12 @@ string MakeUniquePath(string name) { string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, graph_def)); + Status status = WriteTextProto(Env::Default(), path, graph_def); + if (!status.ok()) { + VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status; + path.clear(); + path = "(unavailable)"; + } return path; } @@ -79,7 +84,13 @@ string DumpGraphToFile(const string& name, Graph const& graph, string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, fdef)); + Status status = WriteTextProto(Env::Default(), path, fdef); + if (!status.ok()) { + VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " + << status; + path.clear(); + path = "(unavailable)"; + } return path; } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 5726d8294a7c7fe81d7f6b803af89ca305aa2deb..1d9e0fb33ee4a4229c78d116831e95391a5ac3f8 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -36,6 +37,8 @@ namespace tensorflow { namespace { +using xla::StatusOr; + const char* const kArgOp = "_Arg"; const char* const kRetValOp = "_Retval"; @@ -75,6 +78,20 @@ struct Frame { std::unordered_set nodes; }; +// Comparison function used for sorting nodes consistently. +// a) resource variables are last, and +// b) sort lexicographically by name (for deterministic output). +struct NodeCmp { + bool operator()(const Node* lhs, const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); + } +}; + // Returns a textual representation of the names of the nodes in the input. template string NodesToString(const T& nodes) { @@ -140,7 +157,7 @@ Status CopySubgraph(const Graph& graph, const Frame* frame, return Status::OK(); } -xla::StatusOr AddNode(const NodeDef& node_def, Graph* graph) { +StatusOr AddNode(const NodeDef& node_def, Graph* graph) { Status status; Node* inserted_node = graph->AddNode(node_def, &status); if (!status.ok()) { @@ -149,7 +166,7 @@ xla::StatusOr AddNode(const NodeDef& node_def, Graph* graph) { return inserted_node; } -xla::StatusOr BuildArgNode(Graph* graph, DataType type, int index) { +StatusOr BuildArgNode(Graph* graph, DataType type, int index) { NodeDef arg_def; NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); builder.Attr("T", type); @@ -158,7 +175,7 @@ xla::StatusOr BuildArgNode(Graph* graph, DataType type, int index) { return AddNode(arg_def, graph); } -xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { +StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { NodeDef ret_def; ret_def.set_op(kRetValOp); ret_def.set_name(strings::StrCat(kRetValOp, index)); @@ -309,16 +326,9 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, } frame->args = std::move(args); - // Order the arguments so that: - // a) resource variables are last, and - // b) sort lexicographically by name (for deterministic output). - std::sort(frame->args.begin(), frame->args.end(), - [](const Arg& a, const Arg& b) { - bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE); - bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE); - return std::tie(a_is_resource, a.enter->name()) < - std::tie(b_is_resource, b.enter->name()); - }); + std::sort( + frame->args.begin(), frame->args.end(), + [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); }); if (frame->loop_cond == nullptr) { return errors::InvalidArgument("Loop ", frame->name, @@ -528,259 +538,127 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, class FunctionalizeCond { public: - // Identifies the connected parts of the tf.Cond. - struct ClusterHandle { - explicit ClusterHandle(int representative = -1) - : representative(representative) {} + // All nodes are assumed to be either in no branch, then branch, else branch, + // or both branches (such as merge nodes). + enum Branch { + kElseBranch = 0, + kThenBranch = 1, + kBoth = 2, + kNeither = 3, + kNumBranchTypes = 4 + }; - bool operator==(const ClusterHandle& other) const { - return representative == other.representative; - } + // Returns a textual representation of the Branch b. + static string Branch_Name(FunctionalizeCond::Branch b); - bool operator!=(const ClusterHandle& other) const { - return !(*this == other); - } + // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf + // nodes. That is, attempt to transform every remaining switch and merge nodes + // in the graph into XlaIf nodes. + // Precondition: All while loops have been removed from graph. + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); - bool operator<(const ClusterHandle& other) const { - return representative < other.representative; + private: + // CondArgNode represents a input to the conditional and its corresponding + // switch nodes. + struct CondArgNode { + explicit CondArgNode(Node* input) : input(input) {} + string ToString() const { + return strings::StrCat("input=", input->name(), + " switches=", NodesToString(switch_nodes)); } - bool operator>(const ClusterHandle& other) const { - return representative > other.representative; - } + Node* input; + std::vector switch_nodes; + }; + using CondArgNodes = std::vector; + struct ForwardFlowNode { + explicit ForwardFlowNode(Branch branch = Branch::kNeither) + : branch(branch), count(0) {} string ToString() const { - return strings::StrCat("Cluster_", representative); + return strings::StrCat("branch=", Branch_Name(branch), " count=", count); } - - // Vector of UnionFind indexable by ClusterHandle and Node*. - struct Vector { - explicit Vector(size_t size) : clusters(size) {} - - UnionFind& at(const ClusterHandle& cluster) { - return clusters.at(cluster.representative); - } - - UnionFind& at(const Node* node) { - return clusters.at(node->id()); - } - - UnionFind& operator[](const Node* node) { - return clusters.at(node->id()); - } - - size_t size() const { return clusters.size(); } - - void resize(size_t count) { return clusters.resize(count); } - - private: - std::vector> clusters; - }; - - private: - int representative; + Branch branch; + int count; }; - // Represents a node in the clustered graph consisting of switch_nodes, - // merge_nodes as well as the edges into and out of this node to other - // Clusters. Each Cluster corresponds to a ClusterHandle and has a - // corresponding representative. - struct Cluster { - std::unordered_set switch_nodes; - std::unordered_set merge_nodes; - std::unordered_set in_nodes; - std::unordered_set out_nodes; - - // A member of the ClusterHandle corresponding to this Cluster. - ClusterHandle representative; - bool visited = false; - }; + struct PredicateSwitches { + explicit PredicateSwitches(Node* predicate) : predicate(predicate) {} - // Represent the clustered graph as map from cluster representative to - // Cluster. - using ClusteredGraph = std::map; - - // The arguments and condition of a XlaIf. The arguments are ordered by node - // id in the original graph. - struct CondArgs { - struct CondCmp { - bool operator()(const Node* lhs, const Node* rhs) const { - bool lhs_is_resource = - lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; - bool rhs_is_resource = - rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; - return std::tie(lhs_is_resource, lhs->name()) < - std::tie(rhs_is_resource, rhs->name()); - } - }; - Node* conditional = nullptr; - std::set args; + Node* predicate; + std::vector switches; }; - static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); - - private: FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : clusters_(graph->num_node_ids()), library_(library), graph_(graph) {} - - // Returns a vector of Switch nodes from the clustered graph where the nodes - // are sorted by the number of switch nodes minus number of merge nodes - // from a root of the clustered graph to the given Merge node, with ties - // broken by the representative of the Cluster. This corresponds to sorting by - // nesting depth, from deepest nested to outermost. - std::vector> SortedSwitchNodes(); - - // Returns whether the graph has no conditionals. - bool NoConditionals() const { return merge_nodes_.empty(); } - - // Construct the clustered graph by creating nodes for each cluster and the - // connections between the clusters. Switch and Merge nodes partition - // clusters, so iterate over those. Note: a Cluster may have neither a - // Merge or Switch but will have an in/out edge from a Cluster that has. - void CreateClusters(); - - // Creates the clustered graph by identifying all the edges between different - // clusters and collecting all switch and merge nodes that correspond to a - // cluster. - void CreateClusteredGraph(); - - // If `from` and `to` correspond to different clusters, then merge the nodes - // in the clustered graph corresponding to `from` and `to`. - // - // If `remove_from_graph` is specified then the `from` node is also removed - // from the clustered graph post contracting the edge. - void ContractEdge(Cluster* from, Cluster* to, bool remove_from_graph = false); - - // Converts a Merge node to a XlaIf. This encapsulates the process of - // extracting the bodies needed for the then and else branch, creates a XlaIf - // node, removing the nodes of the branches from the graph and replacing the - // merge node with a XlaIf. - Status ConvertCorrespondingMergeToXlaIf(Cluster* switch_cluster); - - // Removes a Switch cluster feeding directly into a Merge cluster by removing - // the Switch and Merge nodes and collapsing into a single cluster. - Status RemoveTrivialSwitch(Cluster* switch_cluster); - - // Returns the merge cluster corresponding to the switch node. This function - // only returns the merge cluster in the case where we have a switch node that - // is the single entry point for all paths to a common merge cluster, this - // merge cluster may be created by combining multiple merge clusters, that - // share the switch cluster as common ancestor, together. - // - // Switch - // / \ - // Branch Branch - // \ / - // merge_cluster - // - // Note: either of the branches may be empty. The case where both branches are - // empty is handled by RemoveTrivialSwitch. - gtl::optional CreateCorrespondingMergeCluster( - Cluster* switch_cluster); - - // Determines the arguments needed as input to the Merge cluster originating - // from the Switch cluster. - xla::StatusOr DetermineCondArgs(const Cluster& merge_cluster, - const Cluster& switch_cluster); - - // Builds a XlaIfOp to replace the Merge node with. - xla::StatusOr BuildAndAddXlaIfOp(const CondArgs& cond_args, - const Cluster& merge_cluster, - const std::vector& outputs); + : library_(library), graph_(graph) {} + + // Perform the actual cond functionalization. Iterate over groups of switch + // nodes (linked by common predicate), from innermost to outermost, and + // extract into XlaIf nodes. + Status FunctionalizeInternal(); + + // Determines the branch_map (mapping from node to branch of cond) and + // frontier (the nodes where the cond ends). + StatusOr, + std::unordered_set>> + DetermineBranchMapAndFrontier(const std::vector& switches); + + // Returns XlaIf node created from subgraph of merge and switch nodes. This + // encapsulates the process of extracting the bodies needed for the then and + // else branch, creates a XlaIf node, removing the nodes of the branches from + // the graph and replacing the merge node with a XlaIf. + StatusOr ConvertToXlaIf(const CondArgNodes& cond_arg_nodes, + const std::vector& switch_nodes, + const std::vector& merge_nodes, + Node* predicate); + + // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. + StatusOr BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes, + const std::vector& switch_nodes, + const std::vector& merge_nodes, + Node* predicate); // Extracts a function body corresponding to the given input edge of the merge // node. - Status ExtractBody(const CondArgs& cond_args, const Cluster& merge_cluster, - const std::vector& outputs, int input_edge, + Status ExtractBody(const CondArgNodes& cond_arg_nodes, + const std::vector& switch_nodes, + const std::vector& merge_nodes, int input_edge, Graph* body); // Adds all the input edges to `if_node` corresponding to the arguments. - Status AddInputEdges(const CondArgs& cond_args, Node* if_node); + Status AddInputEdges(const CondArgNodes& cond_arg_nodes, Node* predicate, + Node* if_node); // Adds all output edges from the `if_node`. Status AddOutputEdges(const std::vector& outputs, Node* if_node); - // Removes all nodes from the graph that are part of cluster. - void RemoveClusterNodes(Cluster* cluster); - - // Removes all argument nodes that are unused. - template - void RemoveUnusedArgs(const T& args); - - // Removes all Merge nodes in merge_cluster. - void RemoveMergeNodes(Cluster* merge_cluster); - - // Returns the representative member of the corresponding cluster. - ClusterHandle Representative(const Node* node) { - return clusters_.at(node).Get(); - } + // Returns the switches of graph_ (along with grouping predicates) in + // postorder. Dead switch nodes are skipped and removed from the graph. + std::vector DeterminePredicateSwitchOrder(); + + // Update the state for destination based on the state of source and the node + // being updated. + Status Join(const ForwardFlowNode& src_state, const Node* dst, + ForwardFlowNode* dst_state); + + // Ensure that all nodes in the branch_map are dominated by the switch + // nodes. Returns nodes that are not dominated by the switches but are a + // control dependency of a node in the cond, and remove such control + // dependencies. + StatusOr> EnsureDominanceAndReturnNonDominatedControlNodes( + const std::unordered_map& branch_map, + const std::vector& switches); + + // Validates that the frontier of nodes for the conditional + // section are as expected. + Status ValidateFrontier( + const std::unordered_map& branch_map, + const std::unordered_set& frontier); - ClusteredGraph clustered_graph_; - ClusterHandle::Vector clusters_; - std::unordered_set merge_nodes_; - std::unordered_set switch_nodes_; FunctionLibraryDefinition* library_; Graph* graph_; }; -std::ostream& operator<<(std::ostream& os, - const FunctionalizeCond::ClusterHandle& c) { - os << c.ToString(); - return os; -} - -// Returns a dot representation of the clustered graph showing the connections -// between the nodes and the nodes in each cluster. -string DebugString(const Graph& graph, - FunctionalizeCond::ClusterHandle::Vector* clusters) { - string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; - std::map subgraphs; - auto name = [](const Node* n) { - return strings::StrCat(n->type_string(), "_", n->id()); - }; - for (Node* n : graph.nodes()) { - strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"", - name(n), "\"];\n"); - } - for (auto kv : subgraphs) { - strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", - "style=filled; color=lightgrey;", "label = \"", - kv.first.ToString(), "\";\n", kv.second, "}\n"); - } - for (Node* n : graph.nodes()) { - for (Node* in : n->in_nodes()) { - strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); - } - } - return strings::StrCat(ret, "} // end"); -} - -string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { - string ret = "digraph {\ncompound=true;labeljust=\"r\";\n"; - auto name = [](const FunctionalizeCond::Cluster& cluster) { - return cluster.representative.ToString(); - }; - for (auto kv : clustered_graph) { - if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) { - strings::StrAppend( - &ret, kv.first.ToString(), " [label=\"", name(kv.second), - kv.second.switch_nodes.empty() - ? "" - : strings::StrCat(" switches=", kv.second.switch_nodes.size()), - kv.second.merge_nodes.empty() - ? "" - : strings::StrCat(" merges=", kv.second.merge_nodes.size()), - "\"];\n"); - } - } - for (auto kv : clustered_graph) { - for (auto in : kv.second.in_nodes) { - strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); - } - } - return strings::StrCat(ret, "} // end"); -} - bool IsDeadSwitch(const Node* node) { for (const Edge* e : node->out_edges()) { const Node* dst = e->dst(); @@ -796,337 +674,285 @@ bool IsDeadSwitch(const Node* node) { return true; } -void FunctionalizeCond::CreateClusters() { - ClusterHandle source_cluster = ClusterHandle(Graph::kSourceId); - auto& source = clusters_.at(source_cluster); - std::deque>> workqueue; - workqueue.push_back({source_cluster, {}}); - for (Node* node : graph_->nodes()) { - if (IsSwitch(node)) { - switch_nodes_.insert(node); - } else if (IsMerge(node)) { - merge_nodes_.insert(node); - } - ClusterHandle& cluster = clusters_.at(node).Get(); - cluster = ClusterHandle(node->id()); - // Group all source clusters together. - if (node->IsSource() || node->in_edges().empty()) { - clusters_.at(node).Merge(&source); - source.Merge(&clusters_.at(node)); - workqueue.front().second.push_back(node); +string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) { + const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = { + "else", "then", "both", "neither", "count"}; + return branch_name[b]; +} + +Status FunctionalizeCond::ValidateFrontier( + const std::unordered_map& + branch_map, + const std::unordered_set& frontier) { + std::unordered_set pending[kNumBranchTypes]; + for (Node* n : frontier) { + pending[branch_map.at(n).branch].insert(n); + } + TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]); + for (const Node* n : pending[kBoth]) { + TF_RET_CHECK(IsMerge(n)) << n->DebugString(); + // Merge nodes may be in then or else branch too + } + int index = (pending[kThenBranch].size() <= pending[kElseBranch].size()) + ? kThenBranch + : kElseBranch; + int other = 1 - index; + for (const Node* n : pending[index]) { + if (pending[other].find(n) != pending[other].end()) { + return errors::Internal( + "Node (", n->DebugString().c_str(), + ") in both Else and Then branch should be in Both."); } } - - // If there are no Merge nodes, then terminate. - if (merge_nodes_.empty()) { - return; + if (pending[kBoth].empty() && pending[kThenBranch].empty() && + pending[kElseBranch].empty()) { + return errors::Internal("Unexpected empty frontier for switch nodes"); } + return Status::OK(); +} - // Remove all dead Switch nodes. - RemoveUnusedArgs(switch_nodes_); - - // All parent_'s are still nullptr so clusters_ may still be resized. Resize - // conservatively assuming all merge nodes become XlaIf nodes. - clusters_.resize(clusters_.size() + merge_nodes_.size()); - - std::unordered_set marked; - while (!workqueue.empty()) { - auto cluster_queue = workqueue.front(); - VLOG(4) << "Cluster: " << cluster_queue.first << " Queue: {" - << str_util::Join(cluster_queue.second, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, node->id()); - }) - << "}"; - - UnionFind& repr = clusters_.at(cluster_queue.first); - workqueue.pop_front(); - std::deque switch_nodes; - std::deque merge_nodes; - std::unordered_set cluster_member; - while (!cluster_queue.second.empty()) { - // Iterate node workqueue and flow forward merging all nodes reachable - // that are neither a Switch or a Merge and whose inputs are all part of - // the same cluster. - Node* cur = cluster_queue.second.front(); - cluster_queue.second.pop_front(); - if (marked.find(cur) != marked.end()) { - continue; - } - if (IsMerge(cur)) { - merge_nodes.push_back(cur); - marked.insert(cur); - continue; - } - if (IsSwitch(cur)) { - switch_nodes.push_back(cur); - marked.insert(cur); - continue; - } - clusters_.at(cur).Merge(&repr); - cluster_member.insert(cur); - for (Node* out : cur->out_nodes()) { - bool all_ancestors_in_cluster = true; - for (Node* in : out->in_nodes()) { - if (IsMerge(out)) { - merge_nodes.push_back(out); - } - if (IsSwitch(out)) { - switch_nodes.push_back(out); - } - if (cluster_member.find(in) == cluster_member.end()) { - all_ancestors_in_cluster = false; - break; - } - } - if (all_ancestors_in_cluster && out->IsOp()) { - cluster_queue.second.push_back(out); - marked.insert(cur); - } - } +Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, + const Node* dst, ForwardFlowNode* dst_state) { + TF_RET_CHECK(dst_state->branch != Branch::kBoth && + dst_state->branch != Branch::kNumBranchTypes) + << "Unexpected/Invalid branch type: Merging " + << Branch_Name(src_state.branch) << " with " + << Branch_Name(dst_state->branch); + if (dst_state->branch == Branch::kNeither) { + dst_state->branch = src_state.branch; + } else if (src_state.branch != dst_state->branch && + src_state.branch != Branch::kNeither) { + if (IsMerge(dst)) { + dst_state->branch = Branch::kBoth; + } else { + return errors::Internal("Illegal merge: ", src_state.ToString(), " with ", + dst_state->ToString(), " for ", + dst->DebugString()); } + } + ++dst_state->count; + return Status::OK(); +} - VLOG(4) << "Switches: {" - << str_util::Join(switch_nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, node->id()); - }) - << "}"; - - // Merge Switch nodes with common predicate. - std::unordered_map> predicate_to_switch; - for (Node* node : switch_nodes) { - Node* tmp; - TF_CHECK_OK(node->input_node(1, &tmp)); - predicate_to_switch[tmp].push_back(node); - } - for (auto kv : predicate_to_switch) { - Node* first = kv.second.front(); - for (Node* switch_node : kv.second) { - clusters_.at(first).Merge(&clusters_.at(switch_node)); +std::vector +FunctionalizeCond::DeterminePredicateSwitchOrder() { + std::vector dead_switches; + std::vector switch_order; + DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) { + if (IsSwitch(n)) { + if (IsDeadSwitch(n)) { + dead_switches.push_back(n); + } else { + switch_order.push_back(n); } } + }); - // Enqueue each edge of the switch node separately. That is, group all the - // nodes that are due to the true/false edge of the switch together and - // consider all nodes that only have a control dependency on the switch node - // separately. We want to group together all nodes that are part of the same - // branch, as these will be extracted into the `then` and `else` functions - // of the functional if. The ops due to control edges are different as they - // could be involved with either branch and merging them here could result - // in invalid graphs. - for (auto kv : predicate_to_switch) { - ClusterHandle none = ClusterHandle(-1); - ClusterHandle first[2] = {none, none}; - std::deque* queue[2]; - for (auto switch_node : kv.second) { - for (const auto e : switch_node->out_edges()) { - if (IsSwitch(e->dst()) || IsMerge(e->dst())) { - continue; - } - // Control edges are enqueued on their own. - if (e->IsControlEdge()) { - workqueue.push_back({Representative(e->dst()), {e->dst()}}); - continue; - } - // Combine all outputs of the same output port of a switch cluster - // into the same workqueue entry. - if (first[e->src_output()] == none) { - ClusterHandle repr = Representative(e->dst()); - first[e->src_output()] = repr; - workqueue.push_back({repr, {}}); - queue[e->src_output()] = &workqueue.back().second; - } - clusters_.at(first[e->src_output()]).Merge(&clusters_.at(e->dst())); - queue[e->src_output()]->push_back(e->dst()); - } - } - } + // Remove all dead switch nodes. + for (Node* n : dead_switches) { + VLOG(2) << "Removing dead switch: " << n->DebugString(); + graph_->RemoveNode(n); } -} -void FunctionalizeCond::ContractEdge(Cluster* from, Cluster* to, - bool remove_from_graph) { - VLOG(3) << "ContractEdge from = " << from->representative - << " to = " << to->representative; - if (from->representative == to->representative) { - return; + std::vector predicate_switch_order; + if (switch_order.empty()) { + return predicate_switch_order; } - to->merge_nodes.insert(from->merge_nodes.begin(), from->merge_nodes.end()); - from->merge_nodes.clear(); - to->switch_nodes.insert(from->switch_nodes.begin(), from->switch_nodes.end()); - from->switch_nodes.clear(); - - for (Cluster* from_out : from->out_nodes) { - from_out->in_nodes.erase(from); - if (from_out->representative != to->representative) { - from_out->in_nodes.insert(to); - to->out_nodes.insert(from_out); - } - } - from->out_nodes.clear(); - for (Cluster* from_in : from->in_nodes) { - from_in->out_nodes.erase(from); - if (from_in->representative != to->representative) { - from_in->out_nodes.insert(to); - to->in_nodes.insert(from_in); + // Merge Switch nodes with common predicate. + std::unordered_map predicate_index; + // The nodes in switch_order are in reverse topological order, but the + // clustered switches need not be (i.e., when considered as a cluster one + // element of a cluster may be later in the topological order than another + // node whose cluster is later in the topological order of clustered + // switches). + for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { + Node* pred; + TF_CHECK_OK((*it)->input_node(1, &pred)); + if (predicate_index.find(pred) == predicate_index.end()) { + predicate_index[pred] = predicate_switch_order.size(); + predicate_switch_order.emplace_back(pred); } + predicate_switch_order[predicate_index[pred]].switches.push_back(*it); } - from->in_nodes.clear(); - - to->in_nodes.erase(from); - to->out_nodes.erase(from); - clusters_.at(to->representative).Merge(&clusters_.at(from->representative)); - from->visited = true; - - if (remove_from_graph) { - clustered_graph_.erase(from->representative); - } + return predicate_switch_order; } -void FunctionalizeCond::CreateClusteredGraph() { - auto update_cluster_for_node = [this](Node* node) -> Cluster& { - ClusterHandle repr = Representative(node); - Cluster& cluster_node = clustered_graph_[repr]; - cluster_node.representative = repr; - for (const Node* in : node->in_nodes()) { - ClusterHandle other_repr = Representative(in); - // Skip source, sink and internal edges. - if (other_repr == repr) { - continue; +StatusOr> +FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes( + const std::unordered_map& branch_map, + const std::vector& switches) { + std::vector old_control_nodes; + for (const auto& kv : branch_map) { + if (kv.second.count != kv.first->in_edges().size()) { + std::vector delete_edges; + for (const Edge* in : kv.first->in_edges()) { + auto it = branch_map.find(in->src()); + if (it == branch_map.end()) { + if (in->IsControlEdge()) { + old_control_nodes.push_back(in->src()); + delete_edges.push_back(in); + } else { + if (IsSwitch(in->src())) { + if (std::find(switches.begin(), switches.end(), in->src()) == + switches.end()) { + return errors::Internal( + "Unexpected switch node found during flow forward: ", + in->src()->DebugString()); + } + continue; + } + return errors::InvalidArgument( + "Value ", kv.first->name(), "'s input, ", in->src()->name(), + ", is not dominated by switch nodes ", NodesToString(switches)); + } + } } - Cluster& cluster_node_in = clustered_graph_[other_repr]; - cluster_node.in_nodes.insert(&cluster_node_in); - cluster_node_in.out_nodes.insert(&cluster_node); - cluster_node_in.representative = other_repr; - } - for (const Node* out : node->out_nodes()) { - ClusterHandle other_repr = Representative(out); - // Skip source, sink and internal edges. - if (other_repr == repr) { - continue; + // Remove control edges from nodes that are not dominated by the switch + // nodes. New control dependencies will be added between these nodes and + // the XlaIf node inserted. + for (const Edge* e : delete_edges) { + graph_->RemoveEdge(e); } - Cluster& cluster_node_out = clustered_graph_[other_repr]; - cluster_node.out_nodes.insert(&cluster_node_out); - cluster_node_out.in_nodes.insert(&cluster_node); - cluster_node_out.representative = other_repr; } - return cluster_node; - }; - update_cluster_for_node(graph_->source_node()); - for (Node* node : switch_nodes_) { - update_cluster_for_node(node).switch_nodes.insert(node); } - for (Node* node : merge_nodes_) { - update_cluster_for_node(node).merge_nodes.insert(node); - } - - VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_); - VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_); + return old_control_nodes; } -gtl::optional -FunctionalizeCond::CreateCorrespondingMergeCluster(Cluster* switch_cluster) { - VLOG(3) << "CreateCorrespondingMergeCluster for " - << switch_cluster->representative; - std::unordered_set merges; - std::unordered_set dominated; - dominated.insert(switch_cluster); - std::deque queue; - auto enqueue_or_update_merge = [this, &queue, &merges](Cluster* c) { - if (c->merge_nodes.empty()) { - queue.push_back(c); - } else { - merges.insert(c); - } - }; - // Enqueue all the outputs of the switch cluster in the workqueue. - for (auto* out : switch_cluster->out_nodes) { - enqueue_or_update_merge(out); - } - std::unordered_set visited; - while (!queue.empty()) { - Cluster* cur = queue.front(); - queue.pop_front(); - if (visited.find(cur) != visited.end()) { +StatusOr< + std::pair, + std::unordered_set>> +FunctionalizeCond::DetermineBranchMapAndFrontier( + const std::vector& switches) { + std::unordered_map branch_map; + std::unordered_set frontier; + std::vector stack = switches; + std::vector visited(graph_->num_node_ids(), false); + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + if (visited[n->id()]) { continue; } - visited.insert(cur); - // Ensure all inputs to the current node are in the dominated set. - for (Cluster* in : cur->in_nodes) { - if (dominated.find(in) == dominated.end()) { - return gtl::nullopt; + visited[n->id()] = true; + + // Propagate branch state along each edge of a switch node. + bool sink_only = true; + for (const Edge* e : n->out_edges()) { + Node* out = e->dst(); + if (!out->IsOp()) { + continue; + } + sink_only = false; + // Propagate branch information. + ForwardFlowNode& ffn = branch_map[out]; + if (IsSwitch(n)) { + int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); + TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn)); + } else { + TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn)); + } + if (IsMerge(out)) { + if (out->in_edges().size() == ffn.count) { + frontier.insert(out); + } + } else if (!visited[out->id()]) { + stack.push_back(out); } } - for (Cluster* out : cur->out_nodes) { - // No switch nodes beyond the entry one is expected. - if (!out->switch_nodes.empty()) { - return gtl::nullopt; + if (sink_only) { + if (!IsIdentity(n)) { + VLOG(1) << "Feeding into sink: " << n->DebugString(); } - enqueue_or_update_merge(out); } } - auto it = merges.begin(); - Cluster* merge_cluster = *it; - for (++it; it != merges.end(); ++it) { - ContractEdge(*it, merge_cluster); - } - - // TODO(jpienaar): Clean up graph, merging nodes. - return merge_cluster; + if (VLOG_IS_ON(2)) { + for (const auto& kv : branch_map) { + // Append attribute to the graph if running with logging to make the + // changes clearer in the visualization. + kv.first->AddAttr("_XlaFunctionalizeBranch", + Branch_Name(kv.second.branch)); + } + } + return std::make_pair(std::move(branch_map), std::move(frontier)); } -xla::StatusOr FunctionalizeCond::DetermineCondArgs( - const Cluster& merge_cluster, const Cluster& switch_cluster) { - VLOG(2) << "DetermineCondArgs for " << merge_cluster.representative - << " with switch cluster " << switch_cluster.representative; - CondArgs ret; - auto feeds_into_branch_cluster = [&](Node* switch_cluster) { - for (Node* out : switch_cluster->out_nodes()) { - ClusterHandle repr = Representative(out); - if (repr == merge_cluster.representative) { - return true; - } - for (Cluster* in : merge_cluster.in_nodes) { - if (repr == in->representative) { - return true; - } +Status FunctionalizeCond::FunctionalizeInternal() { + std::vector predicate_switch_order = + DeterminePredicateSwitchOrder(); + + // Iterate from innermost set of clustered switches to outermost, replacing + // matching switch->merge subgraphs with single XlaIf nodes. + for (auto it = predicate_switch_order.rbegin(); + it != predicate_switch_order.rend(); ++it) { + auto& ps = *it; + VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " (" + << ps.predicate->name() << ")"; + + std::unordered_map branch_map; + std::unordered_set frontier; + TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier), + DetermineBranchMapAndFrontier(ps.switches)); + + VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_bc", *graph_); + TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); + + // Sort the merge and switch nodes using NodeCmp. The switch-nodes are + // further grouped (post sorting) by input to the switch node as in the + // functionalized form each input will be passed in only once. This grouping + // should retain the sorted order. + CondArgNodes cond_arg_nodes; + std::unordered_map input_index; + std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp()); + for (Node* switch_node : ps.switches) { + Node* in; + TF_RETURN_IF_ERROR(switch_node->input_node(0, &in)); + if (input_index.find(in) == input_index.end()) { + input_index[in] = cond_arg_nodes.size(); + cond_arg_nodes.emplace_back(in); } + cond_arg_nodes.at(input_index.at(in)).switch_nodes.push_back(switch_node); } - return false; - }; - for (Node* switch_cluster_node : switch_cluster.switch_nodes) { - if (!feeds_into_branch_cluster(switch_cluster_node)) { - continue; + std::vector merge_nodes(frontier.begin(), frontier.end()); + std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); + + TF_ASSIGN_OR_RETURN(std::vector old_control_nodes, + EnsureDominanceAndReturnNonDominatedControlNodes( + branch_map, ps.switches)); + + TF_ASSIGN_OR_RETURN( + Node * if_node, + ConvertToXlaIf(cond_arg_nodes, ps.switches, merge_nodes, ps.predicate)); + for (Node* old : old_control_nodes) { + graph_->AddControlEdge(old, if_node); } - Node* tmp; - TF_RETURN_IF_ERROR(switch_cluster_node->input_node(1, &tmp)); - if (ret.conditional == nullptr) { - ret.conditional = tmp; - } else if (ret.conditional != tmp) { - return errors::Unimplemented( - "Switch statements with different conditionals cannot be " - "converted into functional conditional."); + for (auto& del_kv : branch_map) { + graph_->RemoveNode(del_kv.first); } - ret.args.insert(switch_cluster_node); + for (auto& kv : cond_arg_nodes) { + for (Node* node : kv.switch_nodes) { + graph_->RemoveNode(node); + } + } + VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_ac", *graph_); } - return ret; + return Status::OK(); } -xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgs& cond_args, const Cluster& merge_cluster, - const std::vector& outputs) { - VLOG(2) << "Build if op for " << NodesToString(merge_cluster.merge_nodes) - << " with input " << NodesToString(cond_args.args); +StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( + const CondArgNodes& cond_arg_nodes, const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate) { + VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input " + << NodesToString(switch_nodes); NodeDef if_def; // Create a new If node using the name of the merge node. - NodeDefBuilder builder( - strings::StrCat((*merge_cluster.merge_nodes.begin())->name(), "_If"), - "XlaIf"); + NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf"); string branch[] = {"else_branch", "then_branch"}; for (int i = 0; i < 2; ++i) { static std::atomic sequence_num(0LL); @@ -1137,8 +963,11 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( strings::StrCat("_functionalize_if_", branch[i], "_", id)); auto body = xla::MakeUnique(graph_->op_registry()); TF_RETURN_IF_ERROR( - ExtractBody(cond_args, merge_cluster, outputs, i, body.get())); + ExtractBody(cond_arg_nodes, switch_nodes, merge_nodes, i, body.get())); VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); + VLOG(4) << "FunctionalizeControlFlow (" << branch[i] << "): " + << dump_graph::DumpGraphToFile( + strings::StrCat("functionalize_", branch[i]), *body); FunctionDef body_fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); @@ -1148,33 +977,39 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( // Build input type. std::vector inputs; DataTypeVector in_arg_types; - for (const Node* arg : cond_args.args) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - builder.ControlInput(in_edge->src()->name()); - } else { - DataType dtype = arg->input_type(0); - inputs.emplace_back(NodeDefBuilder::NodeOut( - in_edge->src()->name(), in_edge->src_output(), dtype)); - in_arg_types.push_back(dtype); + for (auto& kv : cond_arg_nodes) { + bool inserted = false; + for (const Node* arg : kv.switch_nodes) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + if (!inserted) { + DataType dtype = arg->input_type(0); + inputs.emplace_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), dtype)); + in_arg_types.push_back(dtype); + inserted = true; + } + } } } builder.Attr("Tin", in_arg_types); // Build output type. DataTypeVector out_type; - for (const Node* merge : merge_cluster.merge_nodes) { + for (const Node* merge : merge_nodes) { DataType dtype = merge->output_type(0); out_type.push_back(dtype); } builder.Attr("Tout", out_type); builder.Attr("Tcond", DT_BOOL); - builder.Device(cond_args.conditional->assigned_device_name()); + builder.Device(predicate->assigned_device_name()); // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut(cond_args.conditional->name(), 0, - cond_args.conditional->output_type(0))); + builder.Input( + NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0))); // ... followed by the other inputs. builder.Input(inputs); @@ -1183,64 +1018,31 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( return if_node; } -void FunctionalizeCond::RemoveClusterNodes(Cluster* cluster) { - VLOG(3) << "RemoveClusterNodes for " << cluster->representative; - ClusterHandle repr = cluster->representative; - std::deque to_delete; - for (Node* node : graph_->nodes()) { - if (Representative(node) == repr) { - to_delete.push_back(node); - } - } - for (Node* n : to_delete) { - graph_->RemoveNode(n); - } -} - -template -void FunctionalizeCond::RemoveUnusedArgs(const T& args) { - VLOG(2) << "RemoveUnusedArgs among: " << NodesToString(args); - - std::deque to_delete; - for (Node* arg : args) { - if (IsDeadSwitch(arg)) { - to_delete.push_back(arg); - for (Node* n : arg->out_nodes()) { - to_delete.push_back(n); - } - } - } - for (Node* n : to_delete) { - switch_nodes_.erase(n); - auto it = clustered_graph_.find(Representative(n)); - if (it != clustered_graph_.end()) { - it->second.switch_nodes.erase(n); - } - graph_->RemoveNode(n); - } -} - -Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, - const Cluster& merge_cluster, - const std::vector& outputs, +Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, + const std::vector& switch_nodes, + const std::vector& merge_nodes, int input_edge, Graph* body) { - VLOG(2) << "ExtractBody for " << merge_cluster.representative - << " along edge " << input_edge; + VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " + << input_edge; std::vector squash_src_outputs(graph_->num_node_ids(), false); std::vector node_map(graph_->num_node_ids(), nullptr); int arg_count = 0; - for (const auto* arg : cond_args.args) { - DataType dtype = arg->input_type(0); - TF_ASSIGN_OR_RETURN(Node * arg_node, - BuildArgNode(body, dtype, arg_count++)); - node_map.at(arg->id()) = arg_node; - squash_src_outputs.at(arg->id()) = true; + for (auto& kv : cond_arg_nodes) { + Node* arg_node = nullptr; + for (const auto* arg : kv.switch_nodes) { + DataType dtype = arg->input_type(0); + if (arg_node == nullptr) { + TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++)); + } + node_map.at(arg->id()) = arg_node; + squash_src_outputs.at(arg->id()) = true; + } } std::vector stack; - stack.reserve(outputs.size()); - for (int j = 0; j < outputs.size(); ++j) { - Node* node = outputs[j]; + stack.reserve(merge_nodes.size()); + for (int j = 0; j < merge_nodes.size(); ++j) { + Node* node = merge_nodes[j]; TF_ASSIGN_OR_RETURN(node_map.at(node->id()), BuildRetvalNode(body, node->output_type(0), /*index=*/j)); @@ -1251,7 +1053,8 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, node_map.at(in->id()) = body->CopyNode(in); } - if (cond_args.args.find(in) == cond_args.args.end()) { + if (std::find(switch_nodes.begin(), switch_nodes.end(), in) == + switch_nodes.end()) { body->AddEdge(node_map.at(in->id()), in_edge->src_output(), node_map.at(node->id()), 0); } else { @@ -1266,18 +1069,25 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, body); } -Status FunctionalizeCond::AddInputEdges(const CondArgs& cond_args, - Node* if_node) { +Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, + Node* predicate, Node* if_node) { VLOG(3) << "AddInputEdges for " << if_node->name(); - int i = 0; - graph_->AddEdge(cond_args.conditional, 0, if_node, i++); - for (const Node* arg : cond_args.args) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - graph_->AddControlEdge(in_edge->src(), if_node); - } else { - graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, i++); + int index = 0; + graph_->AddEdge(predicate, 0, if_node, index++); + for (auto& kv : cond_arg_nodes) { + bool inserted = false; + for (const Node* arg : kv.switch_nodes) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph_->AddControlEdge(in_edge->src(), if_node); + } else { + if (!inserted) { + graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, + index++); + inserted = true; + } + } } } return Status::OK(); @@ -1308,186 +1118,27 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, return Status::OK(); } -void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) { - VLOG(3) << "RemoveMergeNodes for " << merge_cluster->representative; - // Remove all merge nodes now dead post extraction of If. - for (auto it = merge_cluster->merge_nodes.begin(); - it != merge_cluster->merge_nodes.end();) { - Node* node = *it; - graph_->RemoveNode(node); - merge_cluster->merge_nodes.erase(*it++); - } -} - -Status FunctionalizeCond::RemoveTrivialSwitch(Cluster* switch_cluster) { - Cluster* merge_cluster = *switch_cluster->out_nodes.begin(); - if (merge_cluster->merge_nodes.empty()) { - return errors::FailedPrecondition( - "Not a trivial switch: no Merge node feeding into Switch node"); - } - - for (auto it = merge_cluster->merge_nodes.begin(); - it != merge_cluster->merge_nodes.end();) { - // We have the following structure: - // Op -> Switch -> Merge -> Consumer - // and we want to transform it to: - // Op -> Consumer - Node* merge_node = *it; - Node* switch_node; - const Edge* in = nullptr; - TF_RETURN_IF_ERROR(merge_node->input_node(0, &switch_node)); - TF_RETURN_IF_ERROR(switch_node->input_edge(0, &in)); - for (auto out : merge_node->out_edges()) { - int src_output = out->dst_input() == Graph::kControlSlot - ? Graph::kControlSlot - : in->src_output(); - graph_->AddEdge(in->src(), src_output, out->dst(), out->dst_input()); - } - graph_->RemoveNode(*it++); - } - RemoveUnusedArgs(switch_cluster->switch_nodes); - - return Status::OK(); -} - -Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf( - Cluster* switch_cluster) { - VLOG(1) << "ConvertMergeToXlaIf for " << switch_cluster->representative; - gtl::optional maybe_merge = - CreateCorrespondingMergeCluster(switch_cluster); - if (!maybe_merge.has_value()) { - return errors::FailedPrecondition( - "Switch cluster was not part of a simple conditional in the clustered " - "graph. Graph nodes in switch cluster ", - NodesToString(switch_cluster->switch_nodes)); - } - Cluster* merge_cluster = *maybe_merge; - if (merge_cluster->merge_nodes.empty()) { - return errors::Internal( - "Merge node in clustered graph contains no merge nodes: ", - merge_cluster->representative.ToString()); - } - TF_ASSIGN_OR_RETURN(auto cond_args, - DetermineCondArgs(*merge_cluster, *switch_cluster)); - - // Sort the outputs by ID to produce more stable output. - std::vector outputs(merge_cluster->merge_nodes.begin(), - merge_cluster->merge_nodes.end()); - std::sort(outputs.begin(), outputs.end(), CondArgs::CondCmp()); +StatusOr FunctionalizeCond::ConvertToXlaIf( + const CondArgNodes& cond_arg_nodes, const std::vector& switch_nodes, + const std::vector& merge_nodes, Node* predicate) { + VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> " + << NodesToString(merge_nodes); // Extract bodies and builds a If operator. - TF_ASSIGN_OR_RETURN(Node * if_node, - BuildAndAddXlaIfOp(cond_args, *merge_cluster, outputs)); - TF_RETURN_IF_ERROR(AddInputEdges(cond_args, if_node)); - TF_RETURN_IF_ERROR(AddOutputEdges(outputs, if_node)); - - // Remove the old nodes from the graph_ and contract the edges of the - // clustered graph. - for (auto in : merge_cluster->in_nodes) { - if (in != switch_cluster) { - RemoveClusterNodes(in); - } - } - RemoveMergeNodes(merge_cluster); - RemoveUnusedArgs(cond_args.args); - auto in_nodes = merge_cluster->in_nodes; - for (auto it = in_nodes.begin(); it != in_nodes.end();) { - ContractEdge(*it++, switch_cluster); - } - ContractEdge(merge_cluster, switch_cluster); - clusters_[if_node].Get() = ClusterHandle(switch_cluster->representative); - - return Status::OK(); -} + TF_ASSIGN_OR_RETURN( + Node * if_node, + BuildAndAddXlaIfOp(cond_arg_nodes, switch_nodes, merge_nodes, predicate)); + TF_RETURN_IF_ERROR(AddInputEdges(cond_arg_nodes, predicate, if_node)); + TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); -std::vector> -FunctionalizeCond::SortedSwitchNodes() { - VLOG(2) << "ProcessClusteredGraph"; - std::stack> stack; - // Initialize with the source node. - stack.push({0, &clustered_graph_[Representative(graph_->source_node())]}); - - // Perform a depth-first traversal of the clustered graph computing the - // switch-merge depth. - std::vector> queue; - std::unordered_set visited; - while (!stack.empty()) { - Cluster* n = stack.top().second; - size_t depth = stack.top().first; - stack.pop(); - - auto inserted = visited.insert(n); - if (!inserted.second) { - continue; - } - - size_t new_depth = depth; - if (!n->merge_nodes.empty()) { - --new_depth; - } - if (!n->switch_nodes.empty()) { - queue.emplace_back(depth, n); - ++new_depth; - } - for (Cluster* e : n->out_nodes) { - stack.emplace(new_depth, e); - } - } - - // Sort in reverse order of switch-merge depth with ties broken by the - // ClusterHandle. - std::sort(queue.begin(), queue.end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return std::tie(lhs.first, lhs.second->representative) > - std::tie(rhs.first, rhs.second->representative); - }); - - return queue; + return if_node; } Status FunctionalizeCond::Functionalize(Graph* graph, FunctionLibraryDefinition* library) { VLOG(1) << "FunctionalizeCond::Functionalize"; FunctionalizeCond fc(graph, library); - fc.CreateClusters(); - if (fc.NoConditionals()) { - return Status::OK(); - } - fc.CreateClusteredGraph(); - - auto queue = fc.SortedSwitchNodes(); - for (auto it = queue.begin(); it != queue.end();) { - Cluster* switch_cluster = (*it).second; - ++it; - if (switch_cluster->out_nodes.size() == 1) { - TF_RETURN_IF_ERROR(fc.RemoveTrivialSwitch(switch_cluster)); - } else { - TF_RETURN_IF_ERROR(fc.ConvertCorrespondingMergeToXlaIf(switch_cluster)); - } - - // Contract newly Switch free switch_cluster with outgoing nodes without - // Switch or Merge nodes. - for (auto& nodes : {switch_cluster->out_nodes, switch_cluster->in_nodes}) { - std::vector copy_nodes(nodes.begin(), nodes.end()); - for (auto* node : copy_nodes) { - if (node->merge_nodes.empty() && node->switch_nodes.empty()) { - fc.ContractEdge(node, switch_cluster); - } - } - } - - VLOG(3) << "Graph with clusters: " - << DebugString(*fc.graph_, &fc.clusters_); - VLOG(3) << "ClusteredGraph: " << DebugString(fc.clustered_graph_); - } - - if (!fc.switch_nodes_.empty()) { - return errors::Internal( - "Failed to functionalize control flow with Switch nodes remaining: ", - NodesToString(fc.switch_nodes_)); - } - return Status::OK(); + return fc.FunctionalizeInternal(); } } // namespace diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 01d2b282751f387cfa9c8887cdeb48090c96bff4..71f12a13339b9b5495631b8f9350579f6a0785a3 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -109,7 +109,7 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName("cond/Merge_If"), less, + auto if_op = ops::XlaIf(scope.WithOpName("cond/Less_If"), less, std::initializer_list{less, y, x}, then_fn, else_fn, {DT_INT32}); GraphDef expected; diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..82b3b46a2f1e97001d1e0c6b993ec243170bc7d8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -0,0 +1,242 @@ +**Supported operators for device: XLA_CPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={float}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={complex64,double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`RandomStandardNormal` | `dtype={float}` +`RandomUniform` | `T={int32,int64}`
`dtype={double,float}` +`RandomUniformInt` | `T={int32,int64}`
`Tout={int32,int64}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`TruncatedNormal` | `T={int32,int64}`
`dtype={double,float}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_CPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..d4b7621ad2858fe17e93d292dd807e4f7c1c336b --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -0,0 +1,238 @@ +**Supported operators for device: XLA_GPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={complex64,double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_GPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 8062f0c03ca60e88bd5c021092dceb105232219f..02215b5112d37f726604da2c2caa4f804388d6e5 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -144,7 +145,9 @@ Status GraphCompiler::Compile() { } else { device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context); Status s = op_context.status(); - TF_RETURN_IF_ERROR(s); + if (!s.ok()) { + return AttachDef(s, n->def()); + } } // Set up outputs. Also check if outputs from the previous computation is diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 6302fece1ffb27b6c7170fcfb90f5985f5b50659..5e1b01878b74f2fbc2e84f8c2db1fa37c2c1eb0e 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -4,6 +4,7 @@ package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) +load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") tf_kernel_library( @@ -30,11 +31,14 @@ tf_kernel_library( "diag_op.cc", "dynamic_stitch_op.cc", "elu_op.cc", + "fft_ops.cc", "fill_op.cc", "function_ops.cc", "gather_op.cc", "gather_op_helpers.h", "identity_op.cc", + "image_ops.cc", + "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", "lrn_ops.cc", @@ -54,11 +58,13 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "scan_ops.cc", "segment_reduction_ops.cc", "select_op.cc", "sendrecv_ops.cc", "sequence_ops.cc", "shape_op.cc", + "shape_util.cc", "slice_op.cc", "softmax_op.cc", "spacetobatch_op.cc", @@ -78,6 +84,7 @@ tf_kernel_library( hdrs = [ "gather_op.h", "index_ops.h", + "shape_util.h", ], deps = [ ":while_op", @@ -85,7 +92,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", + "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -94,9 +103,11 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:framework", + "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:linalg_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:spectral_ops_op_lib", "//tensorflow/core:stateless_random_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", @@ -157,6 +168,7 @@ tf_kernel_library( cc_library( name = "index_ops_kernel_argmax_float_1d", srcs = ["index_ops_kernel_argmax_float_1d.cc"], + copts = tf_copts(), visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", @@ -169,6 +181,7 @@ cc_library( cc_library( name = "index_ops_kernel_argmax_float_2d", srcs = ["index_ops_kernel_argmax_float_2d.cc"], + copts = tf_copts(), visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 248e9d111e556dcdd75581aa6562a66fc8b57063..a249b1869f547f8e5aa725f9f5cf391b10429928 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // XLA implementation of BatchNorm operations. -#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,43 +26,63 @@ namespace { class FusedBatchNormOp : public XlaOpKernel { public: explicit FusedBatchNormOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format); - } + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(0), &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); + + xla::ComputationBuilder* builder = ctx->builder(); + + xla::ComputationDataHandle input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + + int feature_index = + GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. As a workaround, cast everything to the statistics type (which + // may be more precise than the input type). + input = builder->ConvertElementType(input, scale_type); + if (is_training_) { - xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining( - ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, - feature_index_); + xla::ComputationDataHandle output = builder->BatchNormTraining( + input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); - } + ctx->SetOutput(0, builder->ConvertElementType( + builder->GetTupleElement(output, 0), input_type)); + ctx->SetOutput(1, builder->GetTupleElement(output, 1)); + ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + ctx->SetOutput(3, builder->GetTupleElement(output, 1)); + ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } else { - xla::ComputationDataHandle output = ctx->builder()->BatchNormInference( - ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3), - ctx->Input(4), epsilon_, feature_index_); - ctx->SetOutput(0, output); + xla::ComputationDataHandle output = builder->BatchNormInference( + input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), + epsilon_, feature_index); + ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -73,55 +93,113 @@ class FusedBatchNormOp : public XlaOpKernel { private: float epsilon_; - int64 feature_index_; + TensorFormat data_format_; bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); +REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp); class FusedBatchNormGradOp : public XlaOpKernel { public: explicit FusedBatchNormGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - bool is_training; - OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training)); - CHECK(is_training) << "FusedBatchNormGradOp with is_training=False cannot " - "be used with XLA for now!"; - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(4, tensor_format); - } + OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { - auto grad_output = ctx->Input(0); - auto activation = ctx->Input(1); + xla::ComputationBuilder* b = ctx->builder(); + + auto grad_backprop = ctx->Input(0); + auto activations = ctx->Input(1); auto scale = ctx->Input(2); auto mean = ctx->Input(3); auto var = ctx->Input(4); - xla::ComputationDataHandle output = ctx->builder()->BatchNormGrad( - activation, scale, mean, var, grad_output, epsilon_, feature_index_); - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); + TensorShape input_shape = ctx->InputShape(0); + int feature_index = + GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + DataType input_dtype = ctx->input_type(0); + DataType scale_dtype = ctx->input_type(2); + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_dtype, &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(scale_dtype, &scale_type)); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. For now, cast everything to the statistics type (which + // may be more precise than the input type). + grad_backprop = b->ConvertElementType(grad_backprop, scale_type); + activations = b->ConvertElementType(activations, scale_type); + + xla::ComputationDataHandle x_backprop; + xla::ComputationDataHandle scale_backprop; + xla::ComputationDataHandle offset_backprop; + if (is_training_) { + xla::ComputationDataHandle output = + b->BatchNormGrad(activations, scale, mean, var, grad_backprop, + epsilon_, feature_index); + + x_backprop = b->GetTupleElement(output, 0); + scale_backprop = b->GetTupleElement(output, 1); + offset_backprop = b->GetTupleElement(output, 2); + } else { + // Reduce over all dimensions except the feature dim. + std::vector reduction_dims(input_shape.dims() - 1); + std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index, + 0); + std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(), + feature_index + 1); + // offset_backprop = sum(y_backprop) + // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + + // epsilon)) + // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) + offset_backprop = + b->Reduce(grad_backprop, XlaHelpers::Zero(b, scale_dtype), + *ctx->GetOrCreateAdd(scale_dtype), reduction_dims); + + // scratch1 = rsqrt(pop_var + epsilon) + auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); + auto scratch1 = + b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); + + // scratch2 = sum(y_backprop * (x - mean)) + auto scratch2 = b->Reduce( + b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})), + XlaHelpers::Zero(b, scale_dtype), *ctx->GetOrCreateAdd(scale_dtype), + reduction_dims); + + x_backprop = + b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); + scale_backprop = b->Mul(scratch1, scratch2); } - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + + ctx->SetOutput(0, b->ConvertElementType(x_backprop, input_type)); + ctx->SetOutput(1, scale_backprop); + ctx->SetOutput(2, offset_backprop); + ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); + ctx->SetConstantOutput(4, Tensor(scale_dtype, {})); } private: + TensorFormat data_format_; float epsilon_; - int64 feature_index_; + bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); +REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 21d3e64872e19109852297838043975cea6d7921..344a2ab2b6835c518c41de6f7a30fb2a34d130d2 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -159,7 +159,8 @@ class BatchToSpaceNDOp : public XlaOpKernel { block_shape, crops); } }; -REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp); +REGISTER_XLA_OP(Name("BatchToSpaceND").CompileTimeConstInput("crops"), + BatchToSpaceNDOp); class BatchToSpaceOp : public XlaOpKernel { public: @@ -181,7 +182,10 @@ class BatchToSpaceOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp); +REGISTER_XLA_OP(Name("BatchToSpace") + .CompileTimeConstInput("crops") + .CompileTimeConstInput("block_shape"), + BatchToSpaceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index bb031b8c471e08ba90c554e309b850a26c3edae0..ee2c920453c3bbaef2c145df743fddf999167c39 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -65,7 +65,10 @@ class BCastArgsOp : public XlaOpKernel { private: TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp); }; -REGISTER_XLA_OP(Name("BroadcastArgs"), BCastArgsOp); +REGISTER_XLA_OP(Name("BroadcastArgs") + .CompileTimeConstInput("s0") + .CompileTimeConstInput("s1"), + BCastArgsOp); // Given shapes of two tensors, computes the reduction indices for the // gradient computation. @@ -121,7 +124,10 @@ class BCastGradArgsOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp); }; -REGISTER_XLA_OP(Name("BroadcastGradientArgs"), BCastGradArgsOp); +REGISTER_XLA_OP(Name("BroadcastGradientArgs") + .CompileTimeConstInput("s0") + .CompileTimeConstInput("s1"), + BCastGradArgsOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 1de91924326464338352b1ac9edf77141f25ad35..2436a6074a11ad66387b232dd1c5aa135875bfc3 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" namespace tensorflow { namespace { @@ -75,7 +76,7 @@ static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, auto abs_y = b->Abs(y); auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); - if (dtype == DT_FLOAT || dtype == DT_DOUBLE) { + if (DataTypeIsFloating(dtype)) { result = b->Floor(result); } return result; diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 592f3ecc3ce2abf33ddffe8b0e59c4e12e73e956..545aa364f937b2dc972dbe7b8c18b5897aa8e5c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -92,7 +92,8 @@ class CategoricalOp : public XlaOpKernel { }; // TODO(b/68769717): Rename this sampler to Categorical. -REGISTER_XLA_OP(Name("Multinomial"), CategoricalOp); +REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstInput("num_samples"), + CategoricalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 73a4740e29af7fa57e71ef42a342f46b0e24231d..1a246e8df9b2cd83147b50d960744332f8582a51 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -84,8 +84,8 @@ class ConcatBaseOp : public XlaOpKernel { in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar), errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in_shape.DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in_shape.DebugString())); if (in_shape.dims() == 0) { // Inputs that come in as scalars must be reshaped to 1-vectors. input_data.push_back(ctx->builder()->Reshape(handle, {1})); @@ -117,8 +117,11 @@ class ConcatV2Op : public ConcatBaseOp { : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {} }; -REGISTER_XLA_OP(Name("Concat"), ConcatOp); -REGISTER_XLA_OP(Name("ConcatV2").TypeConstraint("Tidx", DT_INT32), ConcatV2Op); +REGISTER_XLA_OP(Name("Concat").CompileTimeConstInput("concat_dim"), ConcatOp); +REGISTER_XLA_OP(Name("ConcatV2") + .TypeConstraint("Tidx", DT_INT32) + .CompileTimeConstInput("axis"), + ConcatV2Op); class ConcatOffsetOp : public XlaOpKernel { public: @@ -189,10 +192,10 @@ class ConcatOffsetOp : public XlaOpKernel { } else { const int32 inp0_element = inp0_literal.Get({j}); const int32 inp_element = inp_literal.Get({j}); - OP_REQUIRES( - ctx, (inp0_element == inp_element), - errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", - inp0_element, " vs. ", inp_element)); + OP_REQUIRES(ctx, (inp0_element == inp_element), + errors::InvalidArgument("input[", i, ",", j, + "] mismatch: ", inp0_element, + " vs. ", inp_element)); out_vec(j) = 0; } } @@ -202,7 +205,10 @@ class ConcatOffsetOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ConcatOffset"), ConcatOffsetOp); +REGISTER_XLA_OP(Name("ConcatOffset") + .CompileTimeConstInput("concat_dim") + .CompileTimeConstInput("shape"), + ConcatOffsetOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index c5017704e2a45b0bd740f7a8fdcf3a0be1d445a4..81cea6d376d02c956a5257c5475fe5c10b83deb9 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -46,72 +46,130 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution( return expanded_shape; } +// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. +xla::ComputationDataHandle CreateExpandedZero( + const TensorShape& filter_shape, DataType dtype, + xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + return builder->Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tesnsor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. +xla::ComputationDataHandle CreateExpandedFilterMask( + const TensorShape& filter_shape, xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::ComputationDataHandle input_feature_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, + &input_feature_iota)); + xla::ComputationDataHandle expanded_feature_iota; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, + input_feature * depthwise_multiplier, + &expanded_feature_iota)); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + builder->Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = builder->Broadcast( + expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); +} + // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding // zeros for the cross-depth filters. Used to build a depthwise convolution. xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter, xla::ComputationBuilder* builder) { - // Filter has shape [H, W, ..., M, N] - // Dilate to [H, W, ..., M*M, N] using M inter-element padding, and then - // reshape to [H, W, ..., M, M*N]. - int num_spatial_dims = filter_shape.dims() - 2; - const int64 in_depth = filter_shape.dim_size(num_spatial_dims); - xla::PaddingConfig padding = xla::MakeNoPaddingConfig(filter_shape.dims()); - padding.mutable_dimensions(num_spatial_dims)->set_interior_padding(in_depth); - auto dilated_filter = - builder->Pad(filter, XlaHelpers::Zero(builder, dtype), padding); - + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Reshape(dilated_filter, expanded_filter_shape.dim_sizes()); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 2, 1); + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 1, + depthwise_multiplier * input_feature); + auto implicit_broadcast_filter = + builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + + // Broadcast the filter to [H, W, ..., M, M*N]. + auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); + auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + + // If the filter mask is set, choose the broadcasted filter, othwerwise, + // choose zero. + return builder->Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( - const TensorShape& filter_shape, DataType dtype, + XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter_backprop, xla::ComputationBuilder* builder) { - int num_spatial_dims = filter_shape.dims() - 2; - - // Reshape to [H, W, ..., M*M, N] - TensorShape shape = filter_shape; - int64 in_depth = filter_shape.dim_size(num_spatial_dims); - shape.set_dim(num_spatial_dims, in_depth * in_depth); - auto reshaped = builder->Reshape(filter_backprop, shape.dim_sizes()); - - std::vector zeros(filter_shape.dims()); - std::vector strides(filter_shape.dims(), 1LL); - strides[num_spatial_dims] = in_depth + 1; - return builder->Slice(reshaped, zeros, shape.dim_sizes(), strides); - - // Alternate implementation for backends without strided Slice() support. - // TODO(phawkins): Remove when all backends support strided slice. - // // Pad [..., M * (M + 1), N] - // xla::PaddingConfig config = - // xla::MakeNoPaddingConfig(filter_shape.dims()); - // config.mutable_dimensions(num_spatial_dims) - // ->set_edge_padding_high(in_depth); - // auto zero = XlaHelpers::Zero(builder, dtype); - // auto padded = builder->Pad(reshaped, zero, config); - // - // // Reshape to [..., M, M + 1, N] - // shape = filter_shape; - // shape.set_dim(num_spatial_dims, in_depth); - // shape.set_dim(num_spatial_dims + 1, in_depth + 1); - // int64 out_depth = filter_shape.dim_size(num_spatial_dims + 1); - // shape.AddDim(out_depth); - // reshaped = builder->Reshape(padded, shape.dim_sizes()); - // - // // Slice to [..., M, 1, N] - // std::vector zeros(shape.dims()); - // std::vector strides(shape.dims(), 1LL); - // shape.set_dim(num_spatial_dims + 1, 1); - // auto sliced = builder->Slice(reshaped, zeros, shape.dim_sizes(), - // strides); - // - // // Reshape to [..., M, N] - // return builder->Reshape(sliced, filter_shape.dim_sizes()); + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + auto masked_expanded_filter = builder->Select( + CreateExpandedFilterMask(filter_shape, builder), filter_backprop, + CreateExpandedZero(filter_shape, dtype, builder)); + return builder->Reshape( + builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), + filter_shape.dim_sizes()); } class ConvOp : public XlaOpKernel { @@ -121,6 +179,7 @@ class ConvOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); @@ -144,6 +203,22 @@ class ConvOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES(ctx, dilations_[input_dim] >= 1, + errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + } + const TensorShape input_shape = ctx->InputShape(0); // Input filter is of the following dimensions: // [ filter_rows, filter_cols, ..., in_depth, out_depth] @@ -172,38 +247,53 @@ class ConvOp : public XlaOpKernel { xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle filter = ctx->Input(1); + TensorShape expanded_filter_shape = filter_shape; if (depthwise_) { filter = ExpandFilterForDepthwiseConvolution( filter_shape, ctx->input_type(0), filter, b); + expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); } xla::ConvolutionDimensionNumbers dims; - std::vector window_strides; + std::vector window_strides(num_spatial_dims_); + std::vector lhs_dilation(num_spatial_dims_, 1); + std::vector rhs_dilation(num_spatial_dims_); + std::vector> padding(num_spatial_dims_); + dims.set_input_batch_dimension(batch_dim); dims.set_output_batch_dimension(batch_dim); dims.set_input_feature_dimension(feature_dim); dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(num_spatial_dims_); + dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); + for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); dims.add_input_spatial_dimensions(dim); dims.add_kernel_spatial_dimensions(i); dims.add_output_spatial_dimensions(dim); - window_strides.push_back(strides_.at(dim)); + window_strides[i] = strides_.at(dim); + rhs_dilation[i] = dilations_.at(dim); + + int64 unused_output_size; + OP_REQUIRES_OK( + ctx, GetWindowedOutputSizeVerboseV2( + input_shape.dim_size(dim), expanded_filter_shape.dim_size(i), + rhs_dilation[i], window_strides[i], padding_, + &unused_output_size, &padding[i].first, &padding[i].second)); } - dims.set_kernel_input_feature_dimension(num_spatial_dims_); - dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); - xla::Padding xla_padding = - (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame; - - xla::ComputationDataHandle conv = b->ConvWithGeneralDimensions( - ctx->Input(0), filter, window_strides, xla_padding, dims); + xla::ComputationDataHandle conv = + b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -241,6 +331,7 @@ class ConvBackpropInputOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -263,6 +354,22 @@ class ConvBackpropInputOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES(ctx, dilations_[input_dim] >= 1, + errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + } + TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -274,10 +381,11 @@ class ConvBackpropInputOp : public XlaOpKernel { : filter_shape; // Reuse dimension computation logic from conv_grad_ops.cc. ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( - type_string(), num_spatial_dims_, input_shape, - expanded_filter_shape, out_backprop_shape, strides_, - padding_, data_format_, &dims)); + OP_REQUIRES_OK(ctx, + ConvBackpropComputeDimensionsV2( + type_string(), num_spatial_dims_, input_shape, + expanded_filter_shape, out_backprop_shape, dilations_, + strides_, padding_, data_format_, &dims)); xla::ComputationBuilder* b = ctx->builder(); auto filter = ctx->Input(1); @@ -301,6 +409,7 @@ class ConvBackpropInputOp : public XlaOpKernel { std::vector kernel_spatial_dims(num_spatial_dims_); std::vector> padding(num_spatial_dims_); std::vector lhs_dilation(num_spatial_dims_); + std::vector rhs_dilation(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); for (int i = 0; i < num_spatial_dims_; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); @@ -312,6 +421,7 @@ class ConvBackpropInputOp : public XlaOpKernel { padding[i] = {dims.spatial_dims[i].pad_before, dims.spatial_dims[i].pad_after}; lhs_dilation[i] = dims.spatial_dims[i].stride; + rhs_dilation[i] = dilations_[dim]; } // If this is a depthwise convolution, expand the filter. @@ -328,7 +438,7 @@ class ConvBackpropInputOp : public XlaOpKernel { // = gradients (with padding and dilation) mirrored_weights xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, /*rhs_dilation=*/ones, dnums); + lhs_dilation, rhs_dilation, dnums); ctx->SetOutput(0, in_backprop); } @@ -336,6 +446,7 @@ class ConvBackpropInputOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -349,21 +460,26 @@ class Conv2DBackpropInputOp : public ConvBackpropInputOp { explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv2DBackpropInput"), Conv2DBackpropInputOp); +REGISTER_XLA_OP( + Name("Conv2DBackpropInput").CompileTimeConstInput("input_sizes"), + Conv2DBackpropInputOp); class Conv3DBackpropInputOp : public ConvBackpropInputOp { public: explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv3DBackpropInputV2"), Conv3DBackpropInputOp); +REGISTER_XLA_OP( + Name("Conv3DBackpropInputV2").CompileTimeConstInput("input_sizes"), + Conv3DBackpropInputOp); class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { public: explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput"), +REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput") + .CompileTimeConstInput("input_sizes"), DepthwiseConv2DBackpropInputOp); class ConvBackpropFilterOp : public XlaOpKernel { @@ -373,6 +489,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -392,6 +509,22 @@ class ConvBackpropFilterOp : public XlaOpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, + errors::Unimplemented("Current implementation does not support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES(ctx, dilations_[input_dim] >= 1, + errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + } + const TensorShape activations_shape = ctx->InputShape(0); TensorShape filter_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); @@ -403,10 +536,11 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Reuse dimension computation logic from conv_grad_ops.cc. ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( - type_string(), num_spatial_dims_, activations_shape, - expanded_filter_shape, out_backprop_shape, strides_, - padding_, data_format_, &dims)); + OP_REQUIRES_OK(ctx, + ConvBackpropComputeDimensionsV2( + type_string(), num_spatial_dims_, activations_shape, + expanded_filter_shape, out_backprop_shape, dilations_, + strides_, padding_, data_format_, &dims)); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle activations = ctx->Input(0); @@ -426,9 +560,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Swap n_dim and c_dim in the activations. dnums.set_input_batch_dimension(c_dim); - dnums.set_output_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); - dnums.set_output_feature_dimension(n_dim); // The gradients become the RHS of the convolution. // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] @@ -438,21 +570,29 @@ class ConvBackpropFilterOp : public XlaOpKernel { std::vector> padding(num_spatial_dims_); std::vector rhs_dilation(num_spatial_dims_); + std::vector window_strides(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. + for (int i = 0; i < num_spatial_dims_; ++i) { + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(num_spatial_dims_); + dnums.set_output_feature_dimension(num_spatial_dims_ + 1); + for (int i = 0; i < num_spatial_dims_; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); - dnums.add_output_spatial_dimensions(dim); // We will also need to pad the input with zeros such that after the // convolution, we get the right size for the filter. // The padded_in_rows should be such that when we convolve this with the // expanded_out_rows as a filter, we should get filter_rows back. // - const int padded_in_size = dims.spatial_dims[i].expanded_output_size + - dims.spatial_dims[i].filter_size - 1; + const int64 padded_in_size = + dims.spatial_dims[i].expanded_output_size + + (dims.spatial_dims[i].filter_size - 1) * dilations_[dim]; // However it can be smaller than input_rows: in this // case it means some of the inputs are not used. @@ -468,8 +608,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { // and input "C" is not used at all. // // We apply negative padding in this case. - const int total_pad_in_size = - padded_in_size - dims.spatial_dims[i].input_size; + const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; // + For the VALID padding, we don't pad anything on the top/left side // and pad the bottom/right side with the remaining space. @@ -479,13 +618,12 @@ class ConvBackpropFilterOp : public XlaOpKernel { // In addition, if the padded input size is smaller than the input size, // we need to ignore some training elements of the input. We do this by // applying negative padding on the right/bottom. - const int before_pad_in_size = - (total_pad_in_size > 0 && padding_ == Padding::SAME) - ? total_pad_in_size / 2 - : 0; + const int64 pad_before = + padding_ == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - padding[i] = {before_pad_in_size, total_pad_in_size - before_pad_in_size}; + padding[i] = {pad_before, pad_total - pad_before}; rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = dilations_[dim]; } // Besides padding the input, we will also expand output_rows to @@ -497,35 +635,20 @@ class ConvBackpropFilterOp : public XlaOpKernel { // This is done by specifying the window dilation factors in the // convolution HLO below. auto filter_backprop = - b->ConvGeneralDilated(activations, gradients, - /*window_strides=*/ones, padding, + b->ConvGeneralDilated(activations, gradients, window_strides, padding, /*lhs_dilation=*/ones, rhs_dilation, dnums); - // The layout of filter_backprop will match the layout of - // padded_activations - // and so will have layout: [out_feature, h, w, ..., in_feature] - // Tensorflow filter shape is [ H, W, ..., inC, outC ], so we transpose the - // output. - std::vector transpose_dims; - transpose_dims.reserve(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - transpose_dims.push_back(dnums.output_spatial_dimensions(i)); - } - transpose_dims.push_back(c_dim); - transpose_dims.push_back(n_dim); - xla::ComputationDataHandle filter_backprop_reshaped = - b->Transpose(filter_backprop, transpose_dims); - if (depthwise_) { - filter_backprop_reshaped = ContractFilterForDepthwiseBackprop( - filter_shape, ctx->input_type(0), filter_backprop_reshaped, b); + filter_backprop = ContractFilterForDepthwiseBackprop( + ctx, filter_shape, ctx->input_type(0), filter_backprop, b); } - ctx->SetOutput(0, filter_backprop_reshaped); + ctx->SetOutput(0, filter_backprop); } protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -540,7 +663,9 @@ class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) { } }; -REGISTER_XLA_OP(Name("Conv2DBackpropFilter"), Conv2DBackpropFilterOp); +REGISTER_XLA_OP( + Name("Conv2DBackpropFilter").CompileTimeConstInput("filter_sizes"), + Conv2DBackpropFilterOp); class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { public: @@ -548,14 +673,17 @@ class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) { } }; -REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2"), Conv3DBackpropFilterOp); +REGISTER_XLA_OP( + Name("Conv3DBackpropFilterV2").CompileTimeConstInput("filter_sizes"), + Conv3DBackpropFilterOp); class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { public: explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx) : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter"), +REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter") + .CompileTimeConstInput("filter_sizes"), DepthwiseConv2DBackpropFilterOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index a4ea65ea89e348cb77412efb0c5c0fcb1a9f33f3..96d7809f7995634b6bc31ab801b93526d9da7e6f 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class DepthToSpaceOp : public XlaOpKernel { public: explicit DepthToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,18 +42,79 @@ class DepthToSpaceOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got: ", input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if (data_format_ == FORMAT_NHWC) { + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i]); + } + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i + 1); + transpose_order.push_back(i + 1 + num_spatial_dims); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] * block_size_); + } + output_shape.push_back(input_shape[feature_dim] / block_elems); + } else { + // NCHW format. + reshaped_shape.push_back(input_shape[0]); + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i]); + } + + transpose_order.push_back(0); + transpose_order.push_back(1 + num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(2 + num_spatial_dims + i); + transpose_order.push_back(1 + i); + } + + output_shape.push_back(input_shape[0]); + output_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] * block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. Reshape `input` to `reshaped` of shape: // // [batch, @@ -51,14 +123,14 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // block_size_, // depth / (block_size_ * block_size_)] - OP_REQUIRES(ctx, input_shape[3] % (block_size_ * block_size_) == 0, + OP_REQUIRES(ctx, + input_shape[feature_dim] % (block_size_ * block_size_) == 0, errors::InvalidArgument( "Input depth dimension (", input_shape[3], ") is not divisible by square of the block size (", block_size_, ")")); - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1], input_shape[2], block_size_, - block_size_, input_shape[3] / (block_size_ * block_size_)}); + + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -70,7 +142,7 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // depth / (block_size_ * block_size_)] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -80,15 +152,14 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] * block_size_, - input_shape[2] * block_size_, - input_shape[3] / (block_size_ * block_size_)}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("DepthToSpace"), DepthToSpaceOp); diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ec5017f6ab96bd3fc273a746b77fbb7e74fd9f35..765ea922a532a085a552192348ab360c4c30ff0a 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +24,62 @@ limitations under the License. namespace tensorflow { namespace { +// Create a diagonal / batch diagonal matrix with 'input' on the diagonal. +xla::StatusOr CreateDiagonal( + const xla::ComputationDataHandle& input, int64 last_dim_size, + tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, + xla::ComputationBuilder* builder) { + // Create two matrices that have the following forms, and compare them: + // + // [[0, 0, 0, 0] [[0, 1, 2, 3] + // [1, 1, 1, 1] [0, 1, 2, 3] + // [2, 2, 2, 2] [0, 1, 2, 3] + // [3, 3, 3, 3]] [0, 1, 2, 3]] + // + // This produces a predicate matrix of the right size, with "true" on the + // diagonal. + xla::ComputationDataHandle iota; + TF_RETURN_IF_ERROR( + XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); + xla::ComputationDataHandle iota_broadcast = + builder->Broadcast(iota, {last_dim_size}); + xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0}); + + // If this is a batched diagonal, broadcast the mask across the other + // dimensions. + if (!other_dims.empty()) { + mask = builder->Broadcast(mask, other_dims); + } + + // Broadcast the input, and then use the mask computed above to select the + // diagonal: + // e.g, in 2D: + // [[t, f, f] [[1, 1, 1] [[0, 0, 0] [[1, 0, 0] + // select( [f, t, f] , [4, 4, 4] , [0, 0, 0] ) = [0, 4, 0] + // [f, f, t]] [9, 9, 9]] [0, 0, 0]] [0, 0, 9]] + // + // Broadcasting the input is less-than-trivial, since we need to broadcast + // into a "middle" dimension. We can do this with a reshape + implicit + // broadcast. + // TODO(b/30112114): Replace with in-dim broadcast when those are supported. + std::vector broadcast_dims(other_dims.begin(), other_dims.end()); + broadcast_dims.push_back(1LL); + broadcast_dims.push_back(last_dim_size); + xla::ComputationDataHandle input_broadcast = + builder->Reshape(input, broadcast_dims); + + broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; + xla::PrimitiveType element_type; + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); + auto broadcast_shape = + xla::ShapeUtil::MakeShape(element_type, broadcast_dims); + xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape); + + input_broadcast = builder->Add(input_broadcast, zeros); + return builder->Select(mask, input_broadcast, zeros); +} + class DiagOp : public XlaOpKernel { public: explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -29,6 +87,8 @@ class DiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("Diag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -36,7 +96,7 @@ class DiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::ComputationDataHandle input = ctx->Input(0); // Picture: // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] @@ -46,13 +106,13 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - diag = builder->Reshape(diag, {size}); + input = builder->Reshape(input, {size}); - // Adds inter-element padding of 'size'. - xla::PaddingConfig config; - auto* dim = config.add_dimensions(); - dim->set_interior_padding(size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); + // Create an R2 with the R1 diagonal. + auto diag_or_status = + CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + xla::ComputationDataHandle diag = diag_or_status.ValueOrDie(); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); @@ -141,6 +201,8 @@ class MatrixDiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("MatrixDiag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -152,17 +214,13 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); + tensorflow::gtl::ArraySlice other_dims(dims); + other_dims.pop_back(); - // Adds inter-element padding of 'last_dim_size' to the last dimension. - xla::PaddingConfig config = xla::MakeNoPaddingConfig(dims.size()); - auto* dim = config.mutable_dimensions(last_dim); - dim->set_interior_padding(last_dim_size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); - - // Reshapes to the final shape. - dims.push_back(last_dim_size); - diag = builder->Reshape(diag, dims); - + auto diag_or_status = + CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + diag = diag_or_status.ValueOrDie(); ctx->SetOutput(0, diag); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 7349dcb987cd88c423570889c0502d1a0bd12c52..f2cd21ffb9ce88747c04f3c71e66dadeb1faf0f9 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -72,22 +72,24 @@ class DynamicStitchOp : public XlaOpKernel { XLAShapeToTensorShape(indices_input[input_num].shape(), &indices_shape)); const TensorShape& data_shape = data_shapes[input_num]; - OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), - errors::InvalidArgument( - "data[", input_num, "].shape = ", - data_shape.DebugString(), " does not start with indices[", - input_num, "].shape = ", indices_shape.DebugString())); - OP_REQUIRES(ctx, - input_num == 0 || SameExtraShape(data0_shape, indices0_shape, - data_shape, indices_shape), - errors::InvalidArgument( - "Need data[0].shape[", indices0_shape.dims(), - ":] = data[", input_num, "].shape[", indices_shape.dims(), - ":], got data[0].shape = ", data0_shape.DebugString(), - ", data[", input_num, "].shape = ", - data_shape.DebugString(), ", indices[0].shape = ", - indices0_shape.DebugString(), ", indices[", input_num, - "].shape = ", indices_shape.DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), + errors::InvalidArgument("data[", input_num, + "].shape = ", data_shape.DebugString(), + " does not start with indices[", input_num, + "].shape = ", indices_shape.DebugString())); + OP_REQUIRES( + ctx, + input_num == 0 || SameExtraShape(data0_shape, indices0_shape, + data_shape, indices_shape), + errors::InvalidArgument( + "Need data[0].shape[", indices0_shape.dims(), ":] = data[", + input_num, "].shape[", indices_shape.dims(), + ":], got data[0].shape = ", data0_shape.DebugString(), ", data[", + input_num, "].shape = ", data_shape.DebugString(), + ", indices[0].shape = ", indices0_shape.DebugString(), + ", indices[", input_num, + "].shape = ", indices_shape.DebugString())); OP_REQUIRES_OK(ctx, XlaHelpers::ReshapeLiteral(indices_input[input_num], @@ -159,8 +161,8 @@ class DynamicStitchOp : public XlaOpKernel { indices0_shape.dims()); std::vector slice_limit(1 + data0_shape.dims() - indices0_shape.dims()); - std::vector stride(1 + data0_shape.dims() - - indices0_shape.dims(), 1); + std::vector stride(1 + data0_shape.dims() - indices0_shape.dims(), + 1); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } @@ -198,8 +200,10 @@ class DynamicStitchOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp); -REGISTER_XLA_OP(Name("ParallelDynamicStitch"), DynamicStitchOp); +REGISTER_XLA_OP(Name("DynamicStitch").CompileTimeConstInput("indices"), + DynamicStitchOp); +REGISTER_XLA_OP(Name("ParallelDynamicStitch").CompileTimeConstInput("indices"), + DynamicStitchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4f3c1c3ad9a928e0552c388a25ed9fcb08edabb --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for FFT. + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace { + +using xla::FftType; + +class GenericFftOp : public XlaOpKernel { + public: + explicit GenericFftOp(OpKernelConstruction* ctx, FftType fft_type, + int fft_rank) + : XlaOpKernel(ctx), fft_type_(fft_type), fft_rank_(fft_rank) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVectorOrHigher(input_shape), + errors::InvalidArgument("input must be at least 1 dimensional")); + + std::vector fft_length; + if (fft_type_ == FftType::RFFT || fft_type_ == FftType::IRFFT) { + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &fft_length)); + OP_REQUIRES(ctx, fft_length.size() == fft_rank_, + errors::InvalidArgument("fft_length must be length ", + fft_rank_, " vector")); + } else { + // Innermost axis provides the FFT length. + for (int i = 0; i < fft_rank_; i++) { + fft_length.push_back( + input_shape.dim_size(input_shape.dims() - fft_rank_ + i)); + } + } + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle fft = + b->Fft(ctx->Input(0), fft_type_, fft_length); + ctx->SetOutput(0, fft); + } + + protected: + const FftType fft_type_; + const int fft_rank_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(GenericFftOp); +}; + +template +class FFTOp : public GenericFftOp { + public: + explicit FFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("FFT"), FFTOp<1>); +REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>); +REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>); + +template +class IFFTOp : public GenericFftOp { + public: + explicit IFFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>); +REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>); +REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>); + +template +class RFFTOp : public GenericFftOp { + public: + explicit RFFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("RFFT").CompileTimeConstInput("fft_length"), RFFTOp<1>); +REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstInput("fft_length"), RFFTOp<2>); +REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstInput("fft_length"), RFFTOp<3>); + +template +class IRFFTOp : public GenericFftOp { + public: + explicit IRFFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstInput("fft_length"), IRFFTOp<1>); +REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstInput("fft_length"), + IRFFTOp<2>); +REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstInput("fft_length"), + IRFFTOp<3>); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 9e090fe01cbfd4dab81b0de21e3a44e42c2ef18e..eaa13b8dfacce9aaca42ce5fcdfa467ce7fa7b7f 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -69,7 +69,7 @@ class FillOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Fill"), FillOp); +REGISTER_XLA_OP(Name("Fill").CompileTimeConstInput("dims"), FillOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index e420f21ca33fe7de9b33f404ce04eae62d9c041e..ffed38249416766850ba10f1069e706570b995fe 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -46,26 +46,15 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( TensorShape slice_shape(input_shape); slice_shape.set_dim(axis, 1); - // TODO(b/37575001) The tensor in which we construct the output during - // the loop must have rank >= 3 as a workaround for lowering issues. - int64 extra_dims = 0; - if (input_shape.dims() < 3) extra_dims = 3 - input_shape.dims(); - TensorShape loop_out_shape; - for (int64 k = 0; k < extra_dims; ++k) loop_out_shape.AddDim(1); loop_out_shape.AppendShape(input_shape_pre_axis); loop_out_shape.AddDim(num_indices); loop_out_shape.AppendShape(input_shape_post_axis); - - // Slices are reshaped into the rank >= 3 shape of the loop carried output. TensorShape loop_out_slice_shape; - for (int64 k = 0; k < extra_dims; ++k) loop_out_slice_shape.AddDim(1); loop_out_slice_shape.AppendShape(input_shape_pre_axis); loop_out_slice_shape.AddDim(1); loop_out_slice_shape.AppendShape(input_shape_post_axis); - // Finally, the loop-carried rank >= 3 output is reshaped to the op's - // specified result shape. TensorShape out_shape; out_shape.AppendShape(input_shape_pre_axis); out_shape.AppendShape(indices_shape); @@ -89,7 +78,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()), // The gather indices are reshaped to rank 1. Loop invariant. xla::ShapeUtil::MakeShape(idxtype, {num_indices}), - // The output array is rank >= 3, and is updated on each loop iteration. + // The output array, which is updated on each loop iteration. xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); @@ -135,12 +124,11 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( bodyb.DynamicSlice(input, start_indices, slice_shape.dim_sizes()), loop_out_slice_shape.dim_sizes()); - // Construct the index into the R3+ output Tensor 0, ..., , 0, ... + // Construct the index into the output Tensor 0, ..., , 0, ... std::vector out_index_vals( loop_out_shape.dims(), bodyb.Reshape(XlaHelpers::Zero(&bodyb, index_type), {1})); - out_index_vals[input_shape_pre_axis.dims() + extra_dims] = - bodyb.Reshape(i, {1}); + out_index_vals[input_shape_pre_axis.dims()] = bodyb.Reshape(i, {1}); auto out_index = bodyb.ConcatInDim(out_index_vals, 0); // Update the output Tensor @@ -198,6 +186,7 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) { } REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice); -REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), + GatherOpDynamicSlice); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f22f384256a8ddd8c05de4a1322aba741dc4d7fd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -0,0 +1,305 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +// Converts 'input' from RGB format to HSV format. +// 'shape' is the shape of the red/green/blue tensors. +std::array RGBToHSV( + XlaOpKernelContext* ctx, xla::ComputationBuilder* b, + const std::array& rgb, DataType dtype, + const TensorShape& shape) { + auto zero = XlaHelpers::Zero(b, dtype); + auto one = XlaHelpers::One(b, dtype); + + auto red = rgb[0]; + auto green = rgb[1]; + auto blue = rgb[2]; + auto value = b->Max(b->Max(red, green), blue); + auto minimum = b->Min(b->Min(red, green), blue); + auto range = b->Sub(value, minimum); + + auto zeros = b->Broadcast(zero, shape.dim_sizes()); + auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros); + + auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range); + + auto hue = b->Select(b->Eq(green, value), + b->Add(b->Mul(norm, b->Sub(blue, red)), + XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)), + b->Add(b->Mul(norm, b->Sub(red, green)), + XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0))); + hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue); + hue = b->Select(b->Gt(range, zero), hue, zeros); + hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue); + return {hue, saturation, value}; +} + +// Converts 'input' from HSV format to RGB format. +std::array HSVToRGB( + xla::ComputationBuilder* b, + const std::array& hsv, DataType dtype) { + xla::ComputationDataHandle hue = hsv[0]; + xla::ComputationDataHandle saturation = hsv[1]; + xla::ComputationDataHandle value = hsv[2]; + auto zero = XlaHelpers::Zero(b, dtype); + auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0); + auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0); + auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0); + + auto dh = b->Mul(hue, six); + auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one); + auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one); + auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one); + auto one_minus_s = b->Sub(one, saturation); + + auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value); + auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value); + auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value); + return {red, green, blue}; +} + +class RGBToHSVOp : public XlaOpKernel { + public: + explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + OP_REQUIRES(context, input_shape.dims() >= 1, + errors::InvalidArgument("input must be at least 1D", + input_shape.DebugString())); + int channel_dim = input_shape.dims() - 1; + int64 channels = input_shape.dim_size(channel_dim); + OP_REQUIRES( + context, channels == 3, + errors::FailedPrecondition("input must have 3 channels but input has ", + channels, " channels.")); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + + xla::ComputationDataHandle red = + b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle green = + b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle blue = + b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); + TensorShape channel_shape = input_shape; + channel_shape.set_dim(channel_dim, 1); + auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), + channel_shape); + + context->SetOutput(0, b->ConcatInDim(hsv, channel_dim)); + } +}; +REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp); + +class HSVToRGBOp : public XlaOpKernel { + public: + explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + OP_REQUIRES(context, input_shape.dims() >= 1, + errors::InvalidArgument("input must be at least 1D", + input_shape.DebugString())); + int channel_dim = input_shape.dims() - 1; + int64 channels = input_shape.dim_size(channel_dim); + OP_REQUIRES( + context, channels == 3, + errors::FailedPrecondition("input must have 3 channels but input has ", + channels, " channels.")); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle hue = + b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle saturation = + b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle value = + b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); + + auto rgb = HSVToRGB(context->builder(), {hue, saturation, value}, + context->input_type(0)); + + context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + } +}; +REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp); + +class AdjustContrastOpV2 : public XlaOpKernel { + public: + explicit AdjustContrastOpV2(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape& input_shape = context->InputShape(0); + const TensorShape& factor_shape = context->InputShape(1); + OP_REQUIRES(context, input_shape.dims() >= 3, + errors::InvalidArgument("input must be at least 3-D, got shape", + input_shape.DebugString())); + int height_dim = input_shape.dims() - 3; + int width_dim = input_shape.dims() - 2; + int channel_dim = input_shape.dims() - 1; + const int64 height = input_shape.dim_size(height_dim); + const int64 width = input_shape.dim_size(width_dim); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape), + errors::InvalidArgument("contrast_factor must be scalar: ", + factor_shape.DebugString())); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle factor = context->Input(1); + + DataType type = context->input_type(0); + + auto output = b->Reduce(input, /*init_value=*/XlaHelpers::Zero(b, type), + /*computation=*/*context->GetOrCreateAdd(type), + {height_dim, width_dim}); + output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); + + std::vector broadcast_dims(input_shape.dims() - 2); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims.back() = channel_dim; + output = b->Add(b->Mul(input, factor), + b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)), + broadcast_dims); + context->SetOutput(0, output); + } +}; +REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2); + +class AdjustSaturationOp : public XlaOpKernel { + public: + explicit AdjustSaturationOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape& input_shape = context->InputShape(0); + const TensorShape& scale_shape = context->InputShape(1); + OP_REQUIRES(context, input_shape.dims() >= 3, + errors::InvalidArgument("input must be at least 3-D, got shape", + input_shape.DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape), + errors::InvalidArgument("scale must be scalar: ", + scale_shape.DebugString())); + const int channel_dim = input_shape.dims() - 1; + const int64 channels = input_shape.dim_size(channel_dim); + OP_REQUIRES( + context, channels == 3, + errors::InvalidArgument("input must have 3 channels but instead has ", + channels, " channels.")); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle scale = context->Input(1); + + DataType type = context->input_type(0); + + xla::ComputationDataHandle red = + b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle green = + b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle blue = + b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); + TensorShape channel_shape = input_shape; + channel_shape.set_dim(channel_dim, 1); + auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), + channel_shape); + + hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale), + XlaHelpers::One(b, type)); + + auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + + context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + } +}; +REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); + +class AdjustHueOp : public XlaOpKernel { + public: + explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape& input_shape = context->InputShape(0); + const TensorShape& delta_shape = context->InputShape(1); + OP_REQUIRES(context, input_shape.dims() >= 3, + errors::InvalidArgument("input must be at least 3-D, got shape", + input_shape.DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape), + errors::InvalidArgument("delta must be scalar: ", + delta_shape.DebugString())); + const int channel_dim = input_shape.dims() - 1; + const int64 channels = input_shape.dim_size(channel_dim); + OP_REQUIRES( + context, channels == 3, + errors::InvalidArgument("input must have 3 channels but instead has ", + channels, " channels.")); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle delta = context->Input(1); + + DataType type = context->input_type(0); + + xla::ComputationDataHandle red = + b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle green = + b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle blue = + b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); + TensorShape channel_shape = input_shape; + channel_shape.set_dim(channel_dim, 1); + auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), + channel_shape); + + auto zero = XlaHelpers::Zero(b, type); + auto one = XlaHelpers::One(b, type); + + auto& hue = hsv[0]; + hue = b->Rem(b->Add(hsv[0], delta), one); + hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue); + + auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + + context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + } +}; +REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f36b3f594826c27b7866d956c855aa3638db9cb4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -0,0 +1,449 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +// We implement bilinear interpolation by upsampling followed by convolution. +// The basic idea is as follows. To scale from NxN to RxR: +// +// 1. S := (N - 1) / gcd(N-1, R-1) +// 2. k := (R - 1) / gcd(N-1, R-1) +// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1) +// +// For example, to Scale from 7x7 -> 15x15: +// +// 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 +// 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 +// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2) +// +// +// The 7x7 -> 15x15 case is much too large to write out in full as an +// example. The smallest interesting example is 3x3 -> 4x4. +// +// S := 2 +// k := 3 +// +// 00 03 06 00 00 00 00 00 00 00 00 00 00 00 00 02 04 06 +// 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00 -> 06 08 10 12 +// 18 21 24 00 00 00 00 00 03 00 00 06 00 00 12 14 16 18 +// 00 00 00 00 00 00 00 00 00 00 00 18 20 22 24 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 09 00 00 12 00 00 15 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 18 00 00 21 00 00 24 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// +// with the following convolutional kernel, with stride [2, 2]: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 + +// Computes the size of the convolutional kernel and stride to use when resizing +// from in_size to out_size. +struct ResizeConvolutionDims { + // Size of the kernel to use. + std::vector kernel_size; + + // Stride of the convolution to use. + std::vector stride; +}; +ResizeConvolutionDims ComputeResizeConvolutionParameters( + gtl::ArraySlice in_size, gtl::ArraySlice out_size) { + CHECK_EQ(in_size.size(), out_size.size()); + int num_spatial_dims = in_size.size(); + ResizeConvolutionDims dims; + dims.kernel_size.resize(num_spatial_dims); + dims.stride.resize(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1) { + // We must handle input size 1 specially because XLA convolution does + // not allow stride 0. + dims.stride[i] = dims.kernel_size[i] = 1; + } else if (out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + dims.stride[i] = dims.kernel_size[i] = 1; + } else { + int64 gcd = MathUtil::GCD(static_cast(in_size[i] - 1), + static_cast(out_size[i] - 1)); + dims.stride[i] = (in_size[i] - 1) / gcd; + dims.kernel_size[i] = (out_size[i] - 1) / gcd; + } + } + return dims; +} + +xla::ComputationDataHandle MakeBilinearResizeKernel( + xla::ComputationBuilder* builder, gtl::ArraySlice kernel_size, + int64 channels) { + // Form a 2D convolution kernel like: + // 1 2 3 2 1 + // 2 4 6 4 2 + // 1/9 * 3 6 9 6 3 + // 2 4 6 4 2 + // 1 2 3 2 1 + // by multiplying two 1D kernels of the form: + // 1/3 * [1 2 3 2 1] + auto make_1d_kernel = [](int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = (i + 1.0f) / n; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; + }; + + xla::ComputationDataHandle channels_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK( + XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + + auto diag = builder->ConvertElementType( + builder->Eq( + builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1, + 2 * kernel_size[1] - 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), + xla::PrimitiveType::F32); + return builder->Mul( + builder->Mul(diag, + builder->ConstantR1(make_1d_kernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}), + builder->ConstantR1(make_1d_kernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); +} + +xla::ComputationDataHandle ResizeUsingDilationAndConvolution( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& input, + const int num_spatial_dims, std::vector in_size, + std::vector out_size, const int64 channels) { + // Picture for a 1x3 to 1x4 resize: + // stride = 2, kernel size = 3 + // Input: + // 3 6 9 + // Input with dilation and padding: + // 0 0 3 0 0 6 0 0 9 0 0 + // Convolution kernel: + // 1/3 * [1 2 3 2 1] + // Output: + // 3 5 7 9 + xla::ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(0); + dimension_numbers.set_output_batch_dimension(0); + dimension_numbers.set_input_feature_dimension(3); + dimension_numbers.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dimension_numbers.add_input_spatial_dimensions(1 + i); + dimension_numbers.add_output_spatial_dimensions(1 + i); + dimension_numbers.add_kernel_spatial_dimensions(i); + } + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, out_size); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::ComputationDataHandle output = builder->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + // Add broadcasts to handle expanding from a size == 1 dimension to a + // size > 1 dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && out_size[i] > 1) { + output = builder->Add(output, builder->ConstantR1(out_size[i], 0), + /*broadcast_dimensions=*/{1 + i}); + } + } + return output; +} + +xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& grad, + const int num_spatial_dims, std::vector in_size, + std::vector grad_size, const int64 channels) { + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, grad_size); + + // To form the backward convolution, we keep the kernel unchanged (it is + // already symmetric) and swap the roles of strides and LHS dilation. + xla::ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(0); + dimension_numbers.set_output_batch_dimension(0); + dimension_numbers.set_input_feature_dimension(3); + dimension_numbers.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dimension_numbers.add_input_spatial_dimensions(1 + i); + dimension_numbers.add_output_spatial_dimensions(1 + i); + dimension_numbers.add_kernel_spatial_dimensions(i); + } + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } + + xla::ComputationDataHandle output = builder->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. + // Opposite of the slice performed by the forward op. + xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4); + bool pad_output = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && grad_size[i] == 1) { + pad_output = true; + padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1); + } + } + if (pad_output) { + output = builder->Pad(output, builder->ConstantR0(0.0f), padding); + } + return output; +} + +class ResizeBilinearOp : public XlaOpKernel { + public: + explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented( + "ResizeBilinear with align_corners=False is not yet implemented")); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + std::vector out_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); + OP_REQUIRES(ctx, out_size.size() == 2, + errors::InvalidArgument("output size must be length 2, got ", + out_size.size())); + OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, + errors::InvalidArgument("output size must be positive, got [", + out_size[0], ",", out_size[1], "]")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle input = ctx->Input(0); + + // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in + // dimension i. + std::vector slice_size = in_size; + bool slice_input = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + slice_input = true; + slice_size[i] = 1; + } + } + if (slice_input) { + input = b->Slice(input, {0, 0, 0, 0}, + {batch, slice_size[0], slice_size[1], channels}, + {1, 1, 1, 1}); + } + + // Output is always type float. + input = b->ConvertElementType(input, xla::F32); + + // Special Case: + // Instead of doing a ResizeUsingDilationAndConvolution directly, + // while (out_size[0]-1) = c * 2^x * (in_size[0]-1) for x>1 c>1, resize the + // image to 2*(in_size[0]-1)+1 x-times and then resize by scale c(int here). + // Instead of resizing directly we resize it iteratively. + // + // Since bilinear resize can be broken down as 2 sequential linear + // operations along different dimensions. + // Given sufficient numerical stability and a cxd is same as resizing axb -> exf -> cxd. + // + // This makes the convolutions kernels smaller and the operation faster. + xla::ComputationDataHandle output = input; + while (in_size != out_size) { + if (in_size[0] != 1 && in_size[1] != 1) { + std::vector k = { + (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), + (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; + if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && + k[0] > 1 && k[1] > 1) { + std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, + (in_size[1] - 1) * 2 + 1}; + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, next_out_size, channels); + input = output; + in_size = next_out_size; + } else { + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, out_size, channels); + in_size = out_size; + } + } else { + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, out_size, channels); + in_size = out_size; + } + } + + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstInput("size"), + ResizeBilinearOp); + +class ResizeBilinearGradOp : public XlaOpKernel { + public: + explicit ResizeBilinearGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented("ResizeBilinearGrad with align_corners=False is " + "not yet implemented")); + + DataType output_dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + TensorShape grad_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, grad_shape.dims() == 4, + errors::InvalidArgument("gradient must be 4-dimensional", + grad_shape.DebugString())); + const int64 grad_batch = grad_shape.dim_size(0); + const std::vector grad_size = {grad_shape.dim_size(1), + grad_shape.dim_size(2)}; + const int64 grad_channels = grad_shape.dim_size(3); + OP_REQUIRES(ctx, batch == grad_batch, + errors::InvalidArgument( + "activations and gradients must have the same batch size (", + batch, " vs. ", grad_batch, ")")); + OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0, + errors::InvalidArgument("gradient size must be positive, got [", + grad_size[0], ",", grad_size[1], "]")); + OP_REQUIRES( + ctx, channels == grad_channels, + errors::InvalidArgument( + "activations and gradients must have the same number of channels (", + channels, " vs. ", grad_channels, ")")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle grad = ctx->Input(0); + + xla::ComputationDataHandle output = grad; + while (in_size != grad_size) { + if (in_size[0] != 1 && in_size[1] != 1) { + std::vector k = { + (static_cast(grad_size[0]) - 1) / ((in_size[0] - 1) * 2), + (static_cast(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)}; + if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && + k[0] > 1 && k[1] > 1) { + std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1, + (in_size[1] - 1) * 2 + 1}; + output = ResizeUsingDilationAndConvolutionGradOp( + b, grad, num_spatial_dims, in_size, next_grad_size, channels); + grad = output; + in_size = next_grad_size; + } else { + output = ResizeUsingDilationAndConvolutionGradOp( + b, grad, num_spatial_dims, in_size, grad_size, channels); + in_size = grad_size; + } + } else { + output = ResizeUsingDilationAndConvolutionGradOp( + b, grad, num_spatial_dims, in_size, grad_size, channels); + in_size = grad_size; + } + } + + output = b->ConvertElementType(output, output_type_); + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; + xla::PrimitiveType output_type_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index e0dc1870f2a4934c35163f0cc10196e8fcbed9be..7bf4b435f526afa93d8a218b191928acb932cd6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -80,7 +80,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/false) {} -REGISTER_XLA_OP(Name("ArgMax").Device(DEVICE_GPU_XLA_JIT), XlaArgMaxOp); +REGISTER_XLA_OP(Name("ArgMax") + .Device(DEVICE_GPU_XLA_JIT) + .CompileTimeConstInput("dimension"), + XlaArgMaxOp); namespace { @@ -90,7 +93,7 @@ class XlaArgMinOp : public XlaArgMinMaxOp { }; XlaArgMinOp::XlaArgMinOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/true) {} -REGISTER_XLA_OP(Name("ArgMin"), XlaArgMinOp); +REGISTER_XLA_OP(Name("ArgMin").CompileTimeConstInput("dimension"), XlaArgMinOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 20946e247a9459d7c8a0d8a666fef24bd32838f2..b1f3c3c298ce0cadf38b9bda715761fe7e2896d7 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -56,10 +56,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel { errors::InvalidArgument("dim must be < input rank (", input_shape.dims(), "), but got: ", dim)); const int64 dim_size = input_shape.dim_size(dim); - OP_REQUIRES( - ctx, dim_size > 0, - errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ", - input_shape.DebugString())); + OP_REQUIRES(ctx, dim_size > 0, + errors::InvalidArgument( + "Reduction axis ", dim, + " is empty in shape: ", input_shape.DebugString())); // The output shape is the input shape contracted along dim. TensorShape output_shape; @@ -113,9 +113,11 @@ class ArgMaxCustomCallOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp); }; -REGISTER_XLA_OP( - Name("ArgMax").TypeConstraint("T", DT_FLOAT).Device(DEVICE_CPU_XLA_JIT), - ArgMaxCustomCallOp); +REGISTER_XLA_OP(Name("ArgMax") + .TypeConstraint("T", DT_FLOAT) + .Device(DEVICE_CPU_XLA_JIT) + .CompileTimeConstInput("dimension"), + ArgMaxCustomCallOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index fcef497e5845d9080bc83b54e92dcf2fdecf5f12..886baf8115243a22b7255a3961c914d4cf6c2ed5 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -23,16 +23,18 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array kMatmulTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; class MatMulOp : public XlaOpKernel { public: explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false) - : XlaOpKernel(ctx) { + : XlaOpKernel(ctx), is_sparse_(is_sparse) { OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); if (is_sparse) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("Ta", &a_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tb", &b_type_)); // SparseMatMul is actually dense matmul with a hint that one or // both of the inputs may contain a lot of zeroes. On CPU these // inputs are dynamically converted to sparse representation @@ -66,14 +68,25 @@ class MatMulOp : public XlaOpKernel { xla::ComputationDataHandle a = ctx->Input(0); xla::ComputationDataHandle b = ctx->Input(1); + if (is_sparse_) { + if (a_type_ == DT_BFLOAT16) { + a = ctx->builder()->ConvertElementType(a, xla::F32); + } + if (b_type_ == DT_BFLOAT16) { + b = ctx->builder()->ConvertElementType(b, xla::F32); + } + } auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a; auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b; ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs)); } private: + bool is_sparse_; bool transpose_a_; bool transpose_b_; + DataType a_type_; + DataType b_type_; }; REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kMatmulTypes), MatMulOp); @@ -85,10 +98,7 @@ class SparseMatMulOp : public MatMulOp { ~SparseMatMulOp() override = default; }; -REGISTER_XLA_OP(Name("SparseMatMul") - .TypeConstraint("Ta", kFloatTypes) - .TypeConstraint("Tb", kFloatTypes), - SparseMatMulOp); +REGISTER_XLA_OP(Name("SparseMatMul"), SparseMatMulOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index bea1d1600b5b5fc0c44f0208d394f25061ecbb68..05a36a031ad73be289604da1b7e56203ff12fbf5 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -92,7 +92,8 @@ class MirrorPadOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp); }; -REGISTER_XLA_OP(Name("MirrorPad"), MirrorPadOp); +REGISTER_XLA_OP(Name("MirrorPad").CompileTimeConstInput("paddings"), + MirrorPadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 2a9cfcb2eb86399bd446db8d591012a7a2f3d667..9f7c9913802d311895479b914b66553e135aa426 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -76,7 +76,7 @@ class OneHotOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); }; -REGISTER_XLA_OP(Name("OneHot"), OneHotOp); +REGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index d841bd37b33c31dbc156fa824ff62a58169a99cb..791351637aee61c5fdd911dd8a48959990514395 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -83,8 +83,8 @@ class PadOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Pad"), PadOp); -REGISTER_XLA_OP(Name("PadV2"), PadOp); +REGISTER_XLA_OP(Name("Pad").CompileTimeConstInput("paddings"), PadOp); +REGISTER_XLA_OP(Name("PadV2").CompileTimeConstInput("paddings"), PadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 2b6053d19dd64a0c893b3613133c8f4691f9cd27..0b5a38967aeb5b4cd66de5220e2c764371440c2d 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -455,14 +455,16 @@ class AvgPool2DGradOp : public AvgPoolGradOp { errors::InvalidArgument("Invalid data format")); } }; -REGISTER_XLA_OP(Name("AvgPoolGrad"), AvgPool2DGradOp); +REGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"), + AvgPool2DGradOp); class AvgPool3DGradOp : public AvgPoolGradOp { public: explicit AvgPool3DGradOp(OpKernelConstruction* ctx) : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {} }; -REGISTER_XLA_OP(Name("AvgPool3DGrad"), AvgPool3DGradOp); +REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), + AvgPool3DGradOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 2421825ead17a3acee9f145f00904d382fb656f4..c0994c434bca5174eaee7b9e63e10432d9c2ed8d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -52,7 +52,8 @@ class RandomUniformOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp); }; -REGISTER_XLA_OP(Name("RandomUniform"), RandomUniformOp); +REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), + RandomUniformOp); class RandomUniformIntOp : public XlaOpKernel { public: @@ -83,7 +84,8 @@ class RandomUniformIntOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp); }; -REGISTER_XLA_OP(Name("RandomUniformInt"), RandomUniformIntOp); +REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstInput("shape"), + RandomUniformIntOp); class RandomStandardNormalOp : public XlaOpKernel { public: @@ -111,7 +113,8 @@ class RandomStandardNormalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp); }; -REGISTER_XLA_OP(Name("RandomStandardNormal"), RandomStandardNormalOp); +REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstInput("shape"), + RandomStandardNormalOp); class TruncatedNormalOp : public XlaOpKernel { public: @@ -183,7 +186,8 @@ class TruncatedNormalOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("TruncatedNormal"), TruncatedNormalOp); +REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"), + TruncatedNormalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 647b6274083cf8886af6c451b746416445a4a2b2..03b13b2924f4b81c1017804c91d5ffb81c44ea0b 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -35,7 +35,7 @@ class SumOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Sum"), SumOp); +REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); class ProdOp : public XlaReductionOp { public: @@ -53,7 +53,8 @@ class ProdOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Prod"), ProdOp); +REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), + ProdOp); class MinOp : public XlaReductionOp { public: @@ -73,7 +74,7 @@ class MinOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Min"), MinOp); +REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); class MaxOp : public XlaReductionOp { public: @@ -93,7 +94,7 @@ class MaxOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Max"), MaxOp); +REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); class MeanOp : public XlaReductionOp { public: @@ -115,7 +116,8 @@ class MeanOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Mean"), MeanOp); +REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), + MeanOp); class AllOp : public XlaReductionOp { public: @@ -133,7 +135,7 @@ class AllOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("All"), AllOp); +REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); class AnyOp : public XlaReductionOp { public: @@ -151,7 +153,7 @@ class AnyOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Any"), AnyOp); +REGISTER_XLA_OP(Name("Any").CompileTimeConstInput("reduction_indices"), AnyOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 5952e752724d1e6953dd4dbb6a8099b847c64d08..af4d64b159c09ed7e01017f25a2b23e58542dc3c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -95,7 +95,7 @@ class ReshapeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reshape"), ReshapeOp); +REGISTER_XLA_OP(Name("Reshape").CompileTimeConstInput("shape"), ReshapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 7489321f72f50c8f55f8da9dabb9f4b5c7797195..e51d386926763ecbb5a943dfb6f872e78901dc69 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -16,7 +16,6 @@ limitations under the License. // XLA-specific reverse Op. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -53,7 +52,8 @@ class ReverseOp : public XlaOpKernel { xla::Literal lax; OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); std::vector revdims(x_shape.dims()); - std::copy(lax.preds().begin(), lax.preds().end(), revdims.begin()); + std::copy(lax.data().begin(), lax.data().end(), + revdims.begin()); std::vector dimensions; for (int d = 0; d < x_shape.dims(); ++d) { @@ -66,7 +66,7 @@ class ReverseOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reverse"), ReverseOp); +REGISTER_XLA_OP(Name("Reverse").CompileTimeConstInput("dims"), ReverseOp); class ReverseV2Op : public XlaOpKernel { public: @@ -104,7 +104,7 @@ class ReverseV2Op : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ReverseV2"), ReverseV2Op); +REGISTER_XLA_OP(Name("ReverseV2").CompileTimeConstInput("axis"), ReverseV2Op); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee4a94164c4a43828eb4feedbfa9d1a9e231ef8f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// TODO(phawkins): implement double-sized windowed reductions in XLA and remove +// the type constraint. +constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; + +class ScanOp : public XlaOpKernel { + public: + ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape tensor_axis_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape), + errors::InvalidArgument("ScanOp: axis must be a scalar, not ", + tensor_axis_shape.DebugString())); + + int64 axis; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis)); + if (axis < 0) { + axis += input_shape.dims(); + } + OP_REQUIRES( + ctx, FastBoundsCheck(axis, input_shape.dims()), + errors::InvalidArgument("ScanOp: Expected scan axis in the range [", + -input_shape.dims(), ", ", input_shape.dims(), + "), but got ", axis)); + + DataType dtype = ctx->input_type(0); + + if (input_shape.num_elements() == 0) { + // Exit early if there is nothing to compute. + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + xla::ComputationBuilder* builder = ctx->builder(); + + std::vector window_strides(input_shape.dims(), 1); + std::vector window_dims(input_shape.dims(), 1); + window_dims[axis] = input_shape.dim_size(axis); + + std::vector> padding(input_shape.dims(), {0, 0}); + padding[axis].first = input_shape.dim_size(axis) - 1; + // In exclusive mode, add an extra padding element so there is a complete + // window of padding before the data starts. + if (exclusive_) { + ++padding[axis].first; + } + if (reverse_) { + std::swap(padding[axis].first, padding[axis].second); + } + + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle init; + const xla::Computation* reducer; + if (sum_) { + init = XlaHelpers::Zero(builder, dtype); + reducer = ctx->GetOrCreateAdd(dtype); + } else { + init = XlaHelpers::One(builder, dtype); + reducer = ctx->GetOrCreateMul(dtype); + } + auto output = builder->ReduceWindowWithGeneralPadding( + ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + + // In exclusive mode, we have computed an extra element containing the sum + // of all the input elements. Slice off this extra "last" element. + if (exclusive_) { + if (reverse_) { + output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1, + 1, axis); + + } else { + output = + builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); + } + } + ctx->SetOutput(0, output); + } + + private: + const bool sum_; // True=cumulative sum. False=cumulative product. + bool reverse_; + bool exclusive_; +}; + +class CumsumOp : public ScanOp { + public: + explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {} +}; +REGISTER_XLA_OP(Name("Cumsum") + .TypeConstraint("T", kScanOpTypes) + .CompileTimeConstInput("axis"), + CumsumOp); + +class CumprodOp : public ScanOp { + public: + explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {} +}; +REGISTER_XLA_OP(Name("Cumprod") + .TypeConstraint("T", kScanOpTypes) + .CompileTimeConstInput("axis"), + CumprodOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 8a67c0b67fcd95f4841c5e011a4e51638eea5b0f..c220edd588071ef262621784015d34cd475b2918 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -43,24 +43,11 @@ xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( TensorShape out_shape(flat_shape); out_shape.set_dim(0, num_segments); - // TODO(b/37575001) The tensor in which we construct the output during - // the loop must have rank >= 3 as a workaround for lowering issues. - int64 extra_dims = 0; - if (out_shape.dims() < 3) { - extra_dims = 3 - out_shape.dims(); - } - TensorShape loop_out_shape; - for (int64 k = 0; k < extra_dims; ++k) { - loop_out_shape.AddDim(1); - } - loop_out_shape.AppendShape(out_shape); - // Slices from the input data are same shape as the input data, except dim 0. TensorShape slice_shape(flat_shape); slice_shape.set_dim(0, 1); - // slices are reshaped into the rank >= 3 shape of the loop-carried output - TensorShape loop_out_slice_shape(loop_out_shape); - loop_out_slice_shape.set_dim(extra_dims, 1); + TensorShape loop_out_slice_shape(out_shape); + loop_out_slice_shape.set_dim(0, 1); // Construct the initial values of the loop-carried variables // Flatten the indices into 1-D for ease of iteration. @@ -70,7 +57,7 @@ xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( auto init_i = builder->ConstantR0(0); auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - loop_out_shape.dim_sizes()); + out_shape.dim_sizes()); xla::PrimitiveType ptype; TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); @@ -83,7 +70,7 @@ xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( // The scatter indices tensor is loop invariant. xla::ShapeUtil::MakeShape(xla::S32, {indices_shape.num_elements()}), // The output data array is updated each loop iteration. - xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); + xla::ShapeUtil::MakeShape(ptype, out_shape.dim_sizes())}); xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); auto init = builder->Tuple({init_i, data_flat, indices_1d, init_out}); @@ -95,7 +82,6 @@ xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( condb.Parameter(0, tuple_shape, "ScatterAddWhileTuple"), 0), condb.ConstantR0(indices_shape.num_elements())); auto cond_status = condb.Build(); - // TF_CHECK_OK(cond_status); auto cond = cond_status.ConsumeValueOrDie(); // Construct the while loop body's function. The implementation of scatter is: @@ -123,11 +109,9 @@ xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( loop_out_slice_shape.dim_sizes()); // Index into the output array. - // Construct the index into the R3+ output array 0, ..., , 0, ... - std::vector out_index_vals( - loop_out_shape.dims(), zero); - out_index_vals[extra_dims] = - bodyb.DynamicSlice(idcs, bodyb.Reshape(i, {1}), {1}); + std::vector out_index_vals(out_shape.dims(), + zero); + out_index_vals[0] = bodyb.DynamicSlice(idcs, bodyb.Reshape(i, {1}), {1}); auto out_index = bodyb.ConcatInDim(out_index_vals, 0); // Slice the output array, update value, and update the output slice. @@ -142,12 +126,10 @@ xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( bodyb.Tuple({ip1, data, idcs, updated_output}); } auto body_status = bodyb.Build(); - // TF_CHECK_OK(body_status); auto body = body_status.ConsumeValueOrDie(); auto gather_while = builder->While(cond, body, init); - auto updated_output = builder->GetTupleElement(gather_while, 3); - return builder->Reshape(updated_output, out_shape.dim_sizes()); + return builder->GetTupleElement(gather_while, 3); } namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index c2b0e1bb4c1a141d0ab3f5b3ff5397d9da620bd8..2c31f8d90891924f6f86a54ccf548de4df87f3bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -138,7 +138,11 @@ class RangeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Range"), RangeOp); +REGISTER_XLA_OP(Name("Range") + .CompileTimeConstInput("start") + .CompileTimeConstInput("limit") + .CompileTimeConstInput("delta"), + RangeOp); class LinSpaceOp : public XlaOpKernel { public: @@ -207,7 +211,11 @@ class LinSpaceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("LinSpace"), LinSpaceOp); +REGISTER_XLA_OP(Name("LinSpace") + .CompileTimeConstInput("start") + .CompileTimeConstInput("stop") + .CompileTimeConstInput("num"), + LinSpaceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 24a99f253d6dc8bb699fff587c363b12c227e821..05354bca5bb089703fdcceb6f44648bbb98d004b 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Shape Ops. +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -27,56 +28,42 @@ namespace { class ShapeOp : public XlaOpKernel { public: - explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - const int rank = input_shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({rank})); - auto vec = shape_constant.vec(); - // TODO(dga): support int64. b/28119922. - for (int i = 0; i < rank; ++i) { - int64 dim_size = input_shape.dim_size(i); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but dim ", i, " is ", dim_size)); - vec(i) = static_cast(dim_size); - } - + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(0, shape_constant); } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("Shape"), ShapeOp); class ShapeNOp : public XlaOpKernel { public: - explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - const TensorShape shape = ctx->InputShape(i); - const int dims = shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({dims})); - auto vec = shape_constant.vec(); - - // TODO(dga): support int64. b/28119922. - for (int j = 0; j < dims; ++j) { - int64 dim_size = shape.dim_size(j); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but shape ", i, " dim ", j, " is ", - dim_size)); - vec(j) = static_cast(dim_size); - } - + const TensorShape input_shape = ctx->InputShape(i); + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(i, shape_constant); } } bool IsExpensive() override { return false; } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); @@ -134,7 +121,7 @@ class ExpandDimsOp : public XlaOpKernel { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal)); - int dim = literal.s32s(0); + int dim = literal.data()[0]; OP_REQUIRES(ctx, (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), @@ -163,7 +150,7 @@ class ExpandDimsOp : public XlaOpKernel { ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); } }; -REGISTER_XLA_OP(Name("ExpandDims"), ExpandDimsOp); +REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); class SqueezeOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..76ea5f525598f511f295eb5a30f3cf603fbf57aa --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" + +#include + +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { + +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant) { + const int dims = input_shape.dims(); + if (shape_constant->dtype() == DT_INT32) { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + if (!FastBoundsCheck(dim_size, std::numeric_limits::max())) { + return errors::InvalidArgument( + "Shape with out_type=int32 does not support tensors > int32max", + " but dim ", i, " is ", dim_size); + } + vec(i) = static_cast(dim_size); + } + } else { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + vec(i) = dim_size; + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h new file mode 100644 index 0000000000000000000000000000000000000000..575086e118080f6799a54d3ae6409b2b641c4341 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Converts a TensorShape to a constant Tensor. +// +// The input TensorShape input_shape is used to populate the elements of +// shape_constant, which is modified in place. +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index fbe8c78d8fb5f800967942555531a50937cad0ca..be1e97bf26fa4cde1b741c8d0b843a85ce33a59c 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -112,7 +112,9 @@ class SliceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Slice"), SliceOp); +REGISTER_XLA_OP( + Name("Slice").CompileTimeConstInput("begin").CompileTimeConstInput("size"), + SliceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 83a87f19a718ce86a105e3c33ab9eaf0faff3a76..01b46e160d1f1f10a43faf7ca35afb42dfde6e33 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -162,7 +162,10 @@ class SpaceToBatchNDOp : public XlaOpKernel { block_shape, paddings); } }; -REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp); +REGISTER_XLA_OP(Name("SpaceToBatchND") + .CompileTimeConstInput("paddings") + .CompileTimeConstInput("block_shape"), + SpaceToBatchNDOp); class SpaceToBatchOp : public XlaOpKernel { public: @@ -184,7 +187,8 @@ class SpaceToBatchOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp); +REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstInput("paddings"), + SpaceToBatchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 89befda346ec06fec23ab1d1c9d910ded8cd806d..806fda632cde64c1b37ae3b9199028d6b6b0a215 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class SpaceToDepthOp : public XlaOpKernel { public: explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,34 +42,100 @@ class SpaceToDepthOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got ", input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if (data_format_ == FORMAT_NHWC) { + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 1 + i, "]=", input_shape[1 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + reshaped_shape.push_back(input_shape[feature_dim]); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 1); + } + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] / block_size_); + } + output_shape.push_back(input_shape[feature_dim] * block_elems); + } else { + // FORMAT_NCHW + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[2 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 2 + i, "]=", input_shape[2 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + reshaped_shape.push_back(input_shape[feature_dim]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 3); + } + transpose_order.push_back(feature_dim); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + + output_shape.push_back(input_shape[0]); + output_shape.push_back(input_shape[feature_dim] * block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] / block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. Reshape `input` to `reshaped` of shape: // // [batch, // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - const int block_rank = 2; - for (int i = 0; i < block_rank; ++i) { - OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, - errors::InvalidArgument( - "input shape[", 1 + i, "]=", input_shape[1 + i], - " is not divisible by block_size=", block_size_)); - } - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1] / block_size_, block_size_, - input_shape[2] / block_size_, block_size_, input_shape[3]}); + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -69,7 +146,7 @@ class SpaceToDepthOp : public XlaOpKernel { // block_size_, block_size_, // depth] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -79,15 +156,14 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] / block_size_, - input_shape[2] / block_size_, - block_size_ * block_size_ * input_shape[3]}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("SpaceToDepth"), SpaceToDepthOp); diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 795eb1794f577e0f7fd2a2068878e540ff0c1a1d..79c435c90a1f57250be90c2c2523bf3d7d231461 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -103,7 +103,7 @@ class SplitOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Split"), SplitOp); +REGISTER_XLA_OP(Name("Split").CompileTimeConstInput("split_dim"), SplitOp); class SplitVOp : public XlaOpKernel { public: @@ -142,8 +142,9 @@ class SplitVOp : public XlaOpKernel { int neg_one_dim = -1; std::vector split_sizes_vec(num_split, -1); const TensorShape split_size_shape = ctx->InputShape(1); - OP_REQUIRES(ctx, split_size_shape.dims() == 1 && - split_size_shape.num_elements() == num_split, + OP_REQUIRES(ctx, + split_size_shape.dims() == 1 && + split_size_shape.num_elements() == num_split, errors::InvalidArgument( "shape of tensor describing " " the output must have dimension 1 and the same " @@ -171,10 +172,11 @@ class SplitVOp : public XlaOpKernel { } OP_REQUIRES( - ctx, (neg_one_dim == -1 && - total_split_size == input_shape.dim_size(split_dim)) || - (neg_one_dim >= 0 && - total_split_size <= input_shape.dim_size(split_dim)), + ctx, + (neg_one_dim == -1 && + total_split_size == input_shape.dim_size(split_dim)) || + (neg_one_dim >= 0 && + total_split_size <= input_shape.dim_size(split_dim)), errors::InvalidArgument("Determined shape must either match " "input shape along split_dim exactly if " "fully specified, or be less than the size of " @@ -206,7 +208,10 @@ class SplitVOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("SplitV"), SplitVOp); +REGISTER_XLA_OP(Name("SplitV") + .CompileTimeConstInput("split_dim") + .CompileTimeConstInput("size_splits"), + SplitVOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index bb7891b31f6d52fd84cf72579c343f50473e1632..d77fb768ef4d124c403a1dc9b321c4f29571d806 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -40,7 +40,7 @@ namespace { Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, TensorShape* stack_shape) { - auto shape_or_status = builder->GetShape(resource->value); + auto shape_or_status = builder->GetShape(resource->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); } @@ -63,22 +63,24 @@ Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, Status MaybeInitializeStack(xla::ComputationBuilder* builder, XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { - if (resource->type != dtype) { + if (resource->type() != dtype) { return errors::InvalidArgument( - "Stack dtype is ", DataTypeString(resource->type), " but op has dtype ", - DataTypeString(dtype), "."); + "Stack dtype is ", DataTypeString(resource->type()), + " but op has dtype ", DataTypeString(dtype), "."); } TensorShape stack_shape; - stack_shape.AddDim(resource->tensor_array_size); + stack_shape.AddDim(resource->tensor_array_size()); stack_shape.AppendShape(elem_shape); - if (resource->value.handle() == 0) { + if (!resource->initialized()) { // Stack has not been initialized. - xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); - resource->value = + xla::ComputationDataHandle zero = + XlaHelpers::Zero(builder, resource->type()); + TF_RETURN_IF_ERROR(resource->SetValue( + dtype, builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()), - builder->ConstantR0(0)}); + builder->ConstantR0(0)}))); } else { // Checks the expected shape matches the actual shape. TensorShape actual_shape; @@ -105,7 +107,9 @@ class StackOp : public XlaOpKernel { OP_REQUIRES( ctx, size >= 0, errors::InvalidArgument( - "XLA compilation requires a fixed stack size upper bound.")); + "XLA compilation requires a fixed stack size upper bound. If " + "you are using tf.while_loop, set the maximum_iterations parameter " + "to fix this issue.")); // We defer initializing the Stack resource until we see the first push. // Otherwise we do not know the shape of the stack elements. @@ -116,7 +120,7 @@ class StackOp : public XlaOpKernel { OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, value, &resource)); - resource->tensor_array_size = size; + resource->set_tensor_array_size(size); ctx->SetResourceOutput(0, resource); } @@ -127,7 +131,7 @@ class StackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackOp); }; -REGISTER_XLA_OP(Name("StackV2"), StackOp); +REGISTER_XLA_OP(Name("StackV2").CompileTimeConstInput("max_size"), StackOp); class StackPushOp : public XlaOpKernel { public: @@ -145,8 +149,8 @@ class StackPushOp : public XlaOpKernel { // Initializes the Stack, if the element shape was not already known. OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = b->GetTupleElement(resource->value, 0); - xla::ComputationDataHandle index = b->GetTupleElement(resource->value, 1); + xla::ComputationDataHandle ta = b->GetTupleElement(resource->value(), 0); + xla::ComputationDataHandle index = b->GetTupleElement(resource->value(), 1); xla::ComputationDataHandle value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. @@ -160,9 +164,11 @@ class StackPushOp : public XlaOpKernel { // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - resource->value = - b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices), - b->Add(index, b->ConstantR0(1))}); + OP_REQUIRES_OK( + ctx, + resource->SetValue( + dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices), + b->Add(index, b->ConstantR0(1))}))); ctx->SetOutput(0, value); } @@ -187,27 +193,22 @@ class StackPopOp : public XlaOpKernel { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - OP_REQUIRES(ctx, resource->type == dtype_, - errors::InvalidArgument( - "Stack dtype is ", DataTypeString(resource->type), - " but Op requested dtype ", DataTypeString(dtype_), ".")); - // There is a somewhat subtle issue here: here "uninitialized" means we have // not yet seen a pop in the order that we compile operators, not the order // that we run them. However, in practice the two orders should be the same // for the sole user of the stack operators (loop gradients). - OP_REQUIRES(ctx, resource->value.handle() != 0, + OP_REQUIRES(ctx, resource->initialized(), errors::InvalidArgument("Stack pop on uninitialized stack")); TensorShape stack_shape; OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); - xla::ComputationDataHandle state = resource->value; + xla::ComputationDataHandle state = resource->value(); xla::ComputationDataHandle ta = b->GetTupleElement(state, 0); xla::ComputationDataHandle index = b->GetTupleElement(state, 1); index = b->Sub(index, b->ConstantR0(1)); - resource->value = b->Tuple({ta, index}); + OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 6af4bd0496e0da926726e3f74376281f539e925a..f0525a5fb86d6d6f0aae954a916186cffc7f3a9f 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -106,7 +106,11 @@ class StridedSliceOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("StridedSlice"), StridedSliceOp); +REGISTER_XLA_OP(Name("StridedSlice") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceOp); class StridedSliceGradOp : public XlaOpKernel { public: @@ -211,7 +215,12 @@ class StridedSliceGradOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp); +REGISTER_XLA_OP(Name("StridedSliceGrad") + .CompileTimeConstInput("shape") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceGradOp); class StridedSliceAssignOp : public XlaOpKernel { public: @@ -320,7 +329,11 @@ class StridedSliceAssignOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp); +REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceAssignOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 351fda251798e43b607fb445f2c98abd57b3d86b..9224072a3cb92b8ff0e99c79e568ca1a76966ed6 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -50,29 +50,30 @@ namespace { Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { - if (resource->kind != XlaResource::kTensorArray) { + if (resource->kind() != XlaResource::kTensorArray) { return errors::InvalidArgument("Unexpected non-TensorArray resource"); } - if (resource->type != dtype) { + if (resource->type() != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(resource->type), + "TensorArray dtype is ", DataTypeString(resource->type()), " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(resource->tensor_array_size >= 0) - << resource->name << " size " << resource->tensor_array_size; + TF_RET_CHECK(resource->tensor_array_size() >= 0) + << resource->name() << " size " << resource->tensor_array_size(); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size()); ta_shape.AppendShape(elem_shape); - if (resource->value.handle() == 0) { - // TensorArray has not been initialized. - xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); - resource->value = builder->Broadcast(zero, ta_shape.dim_sizes()); + if (!resource->initialized()) { + xla::ComputationDataHandle zero = + XlaHelpers::Zero(builder, resource->type()); + TF_RETURN_IF_ERROR(resource->SetValue( + dtype, builder->Broadcast(zero, ta_shape.dim_sizes()))); } else { // Checks the elem_shape matches the TensorArray shape. - auto shape_or_status = builder->GetShape(resource->value); + auto shape_or_status = builder->GetShape(resource->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); } @@ -93,19 +94,17 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, Status CheckTensorArrayIsInitialized(const string& op_name, const XlaResource* resource, DataType dtype) { - if (resource->kind != XlaResource::kTensorArray) { + if (resource->kind() != XlaResource::kTensorArray) { return errors::InvalidArgument( - "Unexpected non-TensorArray resource passed " - "to ", - op_name); + "Unexpected non-TensorArray resource passed to ", op_name); } - if (resource->value.handle() == 0) { + if (!resource->initialized()) { return errors::InvalidArgument("Uninitialized TensorArray passed to ", op_name); } - if (resource->type != dtype) { + if (resource->type() != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(resource->type), + "TensorArray dtype is ", DataTypeString(resource->type()), " but op has dtype ", DataTypeString(dtype), "."); } @@ -177,7 +176,7 @@ class TensorArrayOp : public XlaOpKernel { OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), dtype_, value, &var)); - var->tensor_array_size = size; + var->set_tensor_array_size(size); ctx->SetResourceOutput(0, var); Tensor flow(DT_FLOAT, TensorShape({})); @@ -193,7 +192,8 @@ class TensorArrayOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); }; -REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp); +REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"), + TensorArrayOp); class TensorArrayWriteOp : public XlaOpKernel { public: @@ -213,7 +213,7 @@ class TensorArrayWriteOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); xla::ComputationDataHandle index = ctx->Input(1); xla::ComputationDataHandle value = ctx->Input(2); xla::ComputationDataHandle flow = ctx->Input(3); @@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - resource->value = written; + OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written)); ctx->SetOutput(0, flow); } @@ -259,7 +259,7 @@ class TensorArrayReadOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); xla::ComputationDataHandle index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. @@ -309,7 +309,33 @@ class TensorArrayGatherOp : public XlaOpKernel { auto indices = ctx->Input(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); + + // Look for the case where the gather takes a simple slice from the + // tensor array (0, 1, 2, 3, 4, ..., N) + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok()) { + bool gather_is_dense_slice = true; + for (auto i = 0; i < const_indices.size(); i++) { + if (const_indices[i] != i) { + gather_is_dense_slice = false; + break; + } + } + + if (gather_is_dense_slice) { + std::vector begin(ta_shape.dims(), 0); + std::vector strides(ta_shape.dims(), 1); + std::vector end(ta_shape.dims(), 1); + end[0] = const_indices.size(); + for (auto i = 1; i < ta_shape.dims(); i++) { + end[i] = ta_shape.dim_size(i); + } + ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + return; + } + } xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); @@ -348,35 +374,54 @@ class TensorArrayScatterOp : public XlaOpKernel { const int num_indices = indices_shape.dim_size(0); const xla::ComputationDataHandle indices = ctx->Input(1); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); const xla::ComputationDataHandle value = ctx->Input(2); const xla::ComputationDataHandle flow = ctx->Input(3); - auto slice_dims = value_shape.dim_sizes(); - slice_dims[0] = 1LL; - - std::vector value_starts(value_shape.dims(), 0); - auto value_ends = value_shape.dim_sizes(); - - std::vector value_strides(value_shape.dims(), 1); - - // For every (index, value) pair, update the corresponding TensorArray - // storage. - for (int i = 0; i < num_indices; ++i) { - // Slice out part of the value. - value_starts[0] = i; - value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + // Look for the case where the scatter is for each sub-tensor in order. The + // tensor array implementation allows for this to be a straight addition. + bool scatter_all_elements_in_order = false; + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok() && num_indices == value_shape.dim_size(0)) { + scatter_all_elements_in_order = true; + for (auto i = 0; i < num_indices; i++) { + if (const_indices[i] != i) { + scatter_all_elements_in_order = false; + break; + } + } + } - // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + if (scatter_all_elements_in_order) { + ta = b->Add(ta, value); + } else { + auto slice_dims = value_shape.dim_sizes(); + slice_dims[0] = 1LL; + + std::vector value_starts(value_shape.dims(), 0); + auto value_ends = value_shape.dim_sizes(); + + std::vector value_strides(value_shape.dims(), 1); + + // For every (index, value) pair, update the corresponding TensorArray + // storage. + for (int i = 0; i < num_indices; ++i) { + // Slice out part of the value. + value_starts[0] = i; + value_ends[0] = i + 1; + auto slice = b->Slice(value, value_starts, value_ends, value_strides); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto start_indices = + b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + } } - resource->value = ta; + OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta)); ctx->SetOutput(0, flow); } @@ -405,7 +450,7 @@ class TensorArrayConcatOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); @@ -460,16 +505,17 @@ class TensorArraySplitOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size()); ta_shape.AppendShape(elem_shape); - OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size, - errors::InvalidArgument( - "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", resource->tensor_array_size, ")")); + OP_REQUIRES( + ctx, lengths.size() == resource->tensor_array_size(), + errors::InvalidArgument( + "TensorArray's size is not equal to the size of lengths (", + lengths.size(), " vs. ", resource->tensor_array_size(), ")")); const xla::ComputationDataHandle value = ctx->Input(1); const xla::ComputationDataHandle flow = ctx->Input(3); @@ -479,7 +525,9 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); + OP_REQUIRES_OK( + ctx, resource->SetValue( + dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } @@ -490,7 +538,8 @@ class TensorArraySplitOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp); }; -REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp); +REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"), + TensorArraySplitOp); class TensorArraySizeOp : public XlaOpKernel { public: @@ -500,7 +549,8 @@ class TensorArraySizeOp : public XlaOpKernel { XlaResource* var; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); - size_tensor.scalar()() = static_cast(var->tensor_array_size); + size_tensor.scalar()() = + static_cast(var->tensor_array_size()); ctx->SetConstantOutput(0, size_tensor); } @@ -523,7 +573,7 @@ class TensorArrayGradOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); OP_REQUIRES_OK( - ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type)); + ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type())); TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 9ee6bd892504e683a191484fb09259619759f36d..9aefcd4fc7f94a1dba1c56273c55d0b98fbbfaf2 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -122,7 +122,7 @@ class TileOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TileOp); }; -REGISTER_XLA_OP(Name("Tile"), TileOp); +REGISTER_XLA_OP(Name("Tile").CompileTimeConstInput("multiples"), TileOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 2fc5d40d1059b868eef0a632071e7cccdecaf9f4..c167642174b328a968d7f7ce1f0ad6e0ab8a7a68 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -54,7 +54,8 @@ class TransposeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); std::vector perm(dims); - std::copy(literal.s32s().begin(), literal.s32s().end(), perm.begin()); + std::copy(literal.data().begin(), literal.data().end(), + perm.begin()); std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). @@ -72,8 +73,9 @@ class TransposeOp : public XlaOpKernel { } } for (int i = 0; i < dims; ++i) { - OP_REQUIRES(ctx, bits[i], errors::InvalidArgument( - i, " is missing from 'perm' argument.")); + OP_REQUIRES( + ctx, bits[i], + errors::InvalidArgument(i, " is missing from 'perm' argument.")); } // 0-D, 1-D, and identity transposes do nothing. @@ -87,7 +89,7 @@ class TransposeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Transpose"), TransposeOp); +REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); // InvertPermutation frequently forms part of the gradient of Transpose. // @@ -103,8 +105,9 @@ class InvertPermutationOp : public XlaOpKernel { explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, FastBoundsCheck(ctx->InputShape(0).num_elements(), - std::numeric_limits::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(ctx->InputShape(0).num_elements(), + std::numeric_limits::max()), errors::InvalidArgument("permutation of nonnegative int32s " "must have <= int32 max elements")); @@ -128,7 +131,9 @@ class InvertPermutationOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32), +REGISTER_XLA_OP(Name("InvertPermutation") + .TypeConstraint("T", DT_INT32) + .CompileTimeConstInput("x"), InvertPermutationOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index b19ea22f50d2dd44e8d1d81f5930263f364030e1..68847ae7a2cb926edd9d29007e24b0db7fb5a75f 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/no_op.h" namespace tensorflow { @@ -121,5 +123,26 @@ class ResourceGatherOp : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), ResourceGatherOp); +class VariableShapeOp : public XlaOpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType variable_dtype; + TensorShape shape; + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); + Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); + ctx->SetConstantOutput(0, shape_constant); + } + + private: + DataType out_dtype_; +}; + +REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ead26478ff2a3a1302e95e4ee5dbbf366b04efc6..4a711e4d9b7aedb166a8a0ec9fe9ec2390f01b17 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -39,7 +39,7 @@ Status MakeXlaCompilerArgumentsFromInputs( *has_uninitialized_vars = false; *has_tensor_arrays = false; for (int i = 0; i < ctx->num_inputs(); ++i) { - VLOG(2) << " Input " << i + VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) << " shape: " << ctx->InputShape(i).DebugString(); XlaCompiler::Argument& arg = (*args)[i]; @@ -50,25 +50,25 @@ Status MakeXlaCompilerArgumentsFromInputs( XlaResource* resource; TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource)); - arg.initialized = resource->value.handle() > 0; + arg.initialized = resource->initialized(); arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = resource->kind; + arg.resource_kind = resource->kind(); if (arg.resource_kind == XlaResource::kTensorArray) { *has_tensor_arrays = true; } - arg.type = resource->type; + arg.type = resource->type(); if (arg.initialized) { TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape)); } else { *has_uninitialized_vars = true; } - arg.tensor_array_size = resource->tensor_array_size; - for (const auto& gradient : resource->tensor_array_gradients) { + arg.tensor_array_size = resource->tensor_array_size(); + for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } - arg.name = resource->name; - VLOG(2) << " resource " << resource->name + arg.name = resource->name(); + VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) << " shape: " << xla::ShapeUtil::HumanString(arg.shape) << " initialized: " << arg.initialized; @@ -120,6 +120,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { body_options.use_tuple_arg = true; body_options.return_updated_values_for_all_resources = true; body_options.resolve_compile_time_constants = false; + body_options.is_entry_computation = false; XlaCompiler::CompilationResult body; OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, arguments, &body)); @@ -162,13 +163,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } std::unique_ptr zero = xla::Literal::CreateFromShape(shape); - resource->value = builder->ConstantLiteral(*zero); + OP_REQUIRES_OK(ctx, resource->SetValue( + update.type, builder->ConstantLiteral(*zero))); } // Add any TensorArray gradients touched by the body to the enclosing // graph. for (const string& grad_source : update.tensor_array_gradients_accessed) { - VLOG(4) << "TensorArray " << resource->name << " accessed gradient " + VLOG(4) << "TensorArray " << resource->name() << " accessed gradient " << grad_source; XlaResource* gradient; OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( @@ -177,7 +179,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Add all of the TensorArray gradients to the argument. For simplicity, // we always pass all known gradients. - for (const auto& gradient : resource->tensor_array_gradients) { + for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } @@ -196,14 +198,21 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler::CompileOptions cond_options; cond_options.use_tuple_arg = true; cond_options.resolve_compile_time_constants = false; + cond_options.is_entry_computation = false; XlaCompiler::CompilationResult cond; OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); - xla::Shape body_input_shape = - xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes); - xla::Shape cond_input_shape = - xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes); + OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape body_input_shape = body.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(body_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape cond_input_shape = cond.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(cond_input_shape), + errors::FailedPrecondition("Expected tuple shape")); VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape); @@ -283,10 +292,11 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - builder->GetTupleElement(while_result, pos), builder)); + builder->GetTupleElement(while_result, pos), + /*reset_initial_values=*/false, builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index - << " name: " << resource->name << " modified: " << update.modified + << " name: " << resource->name() << " modified: " << update.modified << " type: " << DataTypeString(update.type) << " shape: " << xla::ShapeUtil::HumanString(update.shape); // Copies the identity of the resource variable from input to output diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 28a5e6a58bb312f4c4821bcce484a08160009d56..9b0e6174475c22e325c090bec5f1d56822e106bc 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -27,7 +27,6 @@ namespace tensorflow { // The current implementation simply unrolls the computation along the batch // dimension. -// TODO(andydavis): add batching support to XLA's Dot operator. xla::StatusOr BatchDot( xla::ComputationBuilder* builder, xla::ComputationDataHandle x, xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) { @@ -52,26 +51,20 @@ xla::StatusOr BatchDot( // The batch dimensions must be equal and the matrix dimensions must be // valid. - std::vector dimensions; - int64 batch_count = 1; + std::vector batch_dimension_numbers; for (int i = 0; i < ndims - 2; ++i) { - int64 x_size = x_shape->dimensions(i); - int64 y_size = y_shape->dimensions(i); - if (x_size != y_size) { + if (x_shape->dimensions(i) != y_shape->dimensions(i)) { return errors::InvalidArgument( "Dimension ", i, " of inputs to BatchedDot must be equal: ", xla::ShapeUtil::HumanString(*x_shape), " vs ", xla::ShapeUtil::HumanString(*y_shape)); } - dimensions.push_back(x_size); - batch_count *= x_size; + batch_dimension_numbers.push_back(i); } int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - int64 x_inner_dim_size = x_shape->dimensions(x_inner_dim); - int64 y_inner_dim_size = y_shape->dimensions(y_inner_dim); - if (x_inner_dim_size != y_inner_dim_size) { + if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) { return errors::InvalidArgument( "Dimensions ", x_inner_dim, " and ", y_inner_dim, " of arguments to BatchedDot must be equal: ", @@ -80,19 +73,22 @@ xla::StatusOr BatchDot( " transpose: ", transpose_y); } - // If there are no batch dimensions, use a regular Dot. This case exists - // to improve the readability of the emitted graphs. - if (dimensions.empty()) { - auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; - return builder->Dot(lhs, rhs); + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::HasZeroElements(*x_shape) || + xla::ShapeUtil::HasZeroElements(*y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape->dimensions(x_outer_dim)); + dimensions.push_back(y_shape->dimensions(y_outer_dim)); + return builder->Broadcast( + builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())), + dimensions); } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape->dimensions(x_outer_dim)); - dimensions.push_back(y_shape->dimensions(y_outer_dim)); - if (x_shape->element_type() == xla::C64 && transpose_x) { x = builder->Conj(x); } @@ -100,55 +96,23 @@ xla::StatusOr BatchDot( y = builder->Conj(y); } - // Reshape input tensors into 3D tensors by flattening the batch - // dimensions. This makes it easier to unroll the batch dimension. - auto x_flat = - builder->Reshape(x, {batch_count, x_shape->dimensions(ndims - 2), - x_shape->dimensions(ndims - 1)}); - auto y_flat = - builder->Reshape(y, {batch_count, y_shape->dimensions(ndims - 2), - y_shape->dimensions(ndims - 1)}); - - // Slice batches into individual matrices and multiply them. - std::vector out_slices; - for (int64 i = 0; i < batch_count; ++i) { - // Slice off individual matrices and reshape to 2D tensors. - auto x_slice = builder->Slice( - x_flat, {i, 0, 0}, - {i + 1, x_shape->dimensions(ndims - 2), x_shape->dimensions(ndims - 1)}, - {1, 1, 1}); - x_slice = builder->Reshape(x_slice, {x_shape->dimensions(ndims - 2), - x_shape->dimensions(ndims - 1)}); - auto y_slice = builder->Slice( - y_flat, {i, 0, 0}, - {i + 1, y_shape->dimensions(ndims - 2), y_shape->dimensions(ndims - 1)}, - {1, 1, 1}); - y_slice = builder->Reshape(y_slice, {y_shape->dimensions(ndims - 2), - y_shape->dimensions(ndims - 1)}); - - // Transpose if needed. - auto lhs = transpose_x ? builder->Transpose(x_slice, {1, 0}) : x_slice; - auto rhs = transpose_y ? builder->Transpose(y_slice, {1, 0}) : y_slice; - - // Multiply matrices and add an outer singleton dimension to the output - // so we can concatenate along the flattened batch dimension later. - auto out = builder->Dot(lhs, rhs); - out = builder->Reshape(out, - {1, dimensions[ndims - 2], dimensions[ndims - 1]}); - out_slices.push_back(out); + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; + return builder->Dot(lhs, rhs); } - // Concatenate output slices and reshape to original number of dimensions. - xla::ComputationDataHandle data; - if (out_slices.empty()) { - // It is illegal to pass an empty list to ConcatInDim. - // The batch count is empty, so both inputs must have zero elements. - // Arbitrarily use the left input as the argument to Reshape(). - data = x; - } else { - data = builder->ConcatInDim(out_slices, 0); + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - return builder->Reshape(data, dimensions); + return builder->DotGeneral(x, y, dot_dnums); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 7ffe0aa6df9b21c4311eb6c8d311fba1e115b3f4..ce24b61b5dc7176f3caa05e3eb9257399fef7926 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -28,7 +28,7 @@ limitations under the License. namespace tensorflow { xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - xla::Shape& shape) { + const xla::Shape& shape) { return builder->Broadcast( builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); @@ -40,6 +40,9 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, case xla::F16: return builder->ConstantR0(static_cast(value)); break; + case xla::BF16: + return builder->ConstantR0(static_cast(value)); + break; case xla::F32: return builder->ConstantR0(static_cast(value)); break; diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 8fba6b5cf247e9b2c26533c53ece8b0d7d4f4c36..fb138b4f736500aac8184770d97fbf930ced69ea 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -25,7 +25,7 @@ namespace tensorflow { // Returns a zero-filled tensor with shape `shape`. xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - xla::Shape& shape); + const xla::Shape& shape); // Returns a floating point scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 576cd9bf9abb43e29d9eb8f706e0f42ac2d038e9..fcbd157c6191655865d5e250fdf71338780bc2a6 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -23,17 +23,17 @@ limitations under the License. namespace tensorflow { Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { - literal->Clear(); + xla::Shape literal_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape( - host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape())); + host_tensor.dtype(), host_tensor.shape(), &literal_shape)); - literal->Reserve(host_tensor.NumElements()); + *literal = xla::Literal(literal_shape); // memcpy over the payload ... // TODO(phawkins): handle string types. size_t total_bytes = host_tensor.TotalBytes(); if (total_bytes > 0) { - void* dst_ptr = literal->MutableInternalData(); + void* dst_ptr = literal->untyped_data(); const void* src_ptr = DMAHelper::base(&host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } @@ -56,7 +56,7 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, *host_tensor = Tensor(target_type, shape); size_t total_bytes = host_tensor->TotalBytes(); if (total_bytes > 0) { - const void* src_ptr = literal.InternalData(); + const void* src_ptr = literal.untyped_data(); void* dst_ptr = DMAHelper::base(host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index d9c839b61019b92b6de3a77a7bec610ae848a9a4..1a0e09758f7cc6714793300c6ece14093a8ad246 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -14,34 +14,59 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +namespace { +const char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE"; +const char kShardingAttribute[] = "_XlaSharding"; +} // namespace -static const char DEVICE_SUFFIX_REPLICATED_CORE[] = "REPLICATED_CORE"; +namespace { +xla::StatusOr> +GetShardingFromNodeDef(const NodeDef& node_def) { + if (!HasNodeAttr(node_def, kShardingAttribute)) { + return tensorflow::gtl::optional(); + } + string value; + xla::OpSharding sharding; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value)); + if (!sharding.ParseFromString(value)) { + return xla::InvalidArgument( + "Experimental _XlaSharding attribute was not a valid encoded " + "xla::OpSharding proto."); + } + return tensorflow::gtl::optional(sharding); +} -static Status CoreOutOfRangeError(int core, int num_cores_per_replica) { +Status CoreOutOfRangeError(int core, int num_cores_per_replica) { return errors::InvalidArgument( "Invalid replicated core id: ", core, "; num_cores_per_replica=", num_cores_per_replica); } +} // namespace xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) { +ParseShardingFromDevice( + const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional explicit_sharding) { if (device_name.empty()) { return tensorflow::gtl::optional(); } - DeviceNameUtils::ParsedName parsed_device; if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) { return errors::InvalidArgument("Malformed assigned device '", device_name, "'"); } - if (!parsed_device.has_type || - !StringPiece(parsed_device.type) - .ends_with(DEVICE_SUFFIX_REPLICATED_CORE)) { + + if (explicit_sharding.has_value()) { + return explicit_sharding; + } else if (!parsed_device.has_type || !parsed_device.has_id || + !StringPiece(parsed_device.type) + .contains(kDeviceSuffixReplicatedCore)) { return tensorflow::gtl::optional(); } else { const int core = parsed_device.id; @@ -49,24 +74,38 @@ ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) { return CoreOutOfRangeError(core, num_cores_per_replica); } return tensorflow::gtl::optional( - xla::ShardingBuilder::AssignDevice(core)); + xla::sharding_builder::AssignDevice(core)); } } +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica) { + const string& device_name = node_def.device(); + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node_def)); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); +} + xla::StatusOr> ParseShardingFromDevice(const Node& node, int num_cores_per_replica) { string device_name = node.assigned_device_name(); if (device_name.empty()) { device_name = node.requested_device(); } - return ParseShardingFromDevice(device_name, num_cores_per_replica); + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node.def())); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); } + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { string device_name = src.assigned_device_name(); if (device_name.empty()) { device_name = src.requested_device(); } dst->set_assigned_device_name(device_name); + if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) { + dst->AddAttr(kShardingAttribute, *attr); + } } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index f6468bba9f950fec88dcc6b3ec760f014d3a0ef3..b1c817bdcc211648b16e395313ca171d1acb9ea9 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" @@ -29,14 +29,21 @@ namespace tensorflow { // - if the device name is invalid. // - the core is parsed and is out of the range [0, num_cores_per_replica). // -// Otherwise, returns either a non-value or a sharding set as per -// xla:ShardingBuilder::AssignDevice. +// Otherwise, returns either: +// - explicit_sharding if explicit_sharding.has_value() +// - a non-value if there is no assigned core or +// - a sharding set as per xla::sharding_builder::AssignDevice. xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica); +ParseShardingFromDevice(const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional + explicit_sharding = tensorflow::gtl::nullopt); xla::StatusOr> ParseShardingFromDevice(const Node& node, int num_cores_per_replica); +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica); + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index a14c93a2b9494b89f579bc20ee0510c136f8f01b..906f2290433face4cce3296b2f815d50d8c496ce 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -253,8 +253,7 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( @@ -277,7 +276,6 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), "tfcompile", std::move(graph), xla_args, &result)); - *requires_runtime_context = result.requires_runtime_context; *computation = std::move(*result.computation); int num_const_results = 0; @@ -352,12 +350,10 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); - TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation, - requires_runtime_context)); + TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index ab99beebf7946237425d4d304a858ac6817177b8..473c431b12d441c652f1d0d6c11c5e87836ab36d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -30,13 +30,9 @@ namespace tensorflow { // // The computation is built in the context of the given `client`, which may // subsequently be used to compile or execute the computation. -// -// If `requires_runtime_context` is filled with true, this indicates the last -// argument of the computation is XlaLocalRuntimeContext*. Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context); + xla::Computation* computation); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..7aca889a266439538c4cd1c153460e6cc871b246 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace tf2xla { +namespace { + +void PrintSupportedOps(const string& device, const string& regen_run) { + XlaOpRegistry::RegisterCompilationKernels(); + + std::vector kdefs = + XlaOpRegistry::DeviceKernels(device, + /*include_compilation_only_kernels=*/true); + std::sort( + kdefs.begin(), kdefs.end(), + [](const KernelDef* a, const KernelDef* b) { return a->op() < b->op(); }); + + std::cout << "**Supported operators for device: " << device << "**\n\n" + << "Operator | Type Constraint\n" + << "-------- | ---------------" << std::endl; + for (const KernelDef* kdef : kdefs) { + std::vector constraints; + for (const KernelDef::AttrConstraint& constraint : kdef->constraint()) { + std::vector types; + for (int type : constraint.allowed_values().list().type()) { + types.push_back(DataTypeString(static_cast(type))); + } + std::sort(types.begin(), types.end()); + constraints.push_back("`" + constraint.name() + "={" + + str_util::Join(types, ",") + "}`"); + } + std::cout << "`" << kdef->op() << "` | " + << str_util::Join(constraints, "
") << std::endl; + } + + std::cout << "\nTo regenerate this table, run:\n\n```shell\n" + << regen_run << " --device=" << device << "\n```" << std::endl; +} + +} // namespace + +void SupportedOpsMain(int argc, char** argv, const char* regen_run) { + std::vector device_names = XlaOpRegistry::BackendNames(); + std::sort(device_names.begin(), device_names.end()); + + // Set up and parse flags. + string device; + std::vector flag_list = { + {"device", &device, + "Name of the compilation device for which to print supported ops, " + "one of: " + + str_util::Join(device_names, ",")}, + }; + string usage = Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + QCHECK(XlaOpRegistry::IsBackendRegistered(device)) + << "\nUnknown device: " << device << "\n" + << usage; + + // Run the program. + port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " + "other than flags\n\n" + << usage; + PrintSupportedOps(device, regen_run); +} + +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..1b45fb4cdd3b0173b04e130b7416874a9a406dc5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ + +namespace tensorflow { +namespace tf2xla { + +// The implementation of a main function for a binary that prints a table of +// supported tf2xla operators for a given device, along with their type +// constraints, to stdout. +// +// Pass the argc and argv from main, unmodified. Use regen_run to specify the +// command used to regenerate the table. +void SupportedOpsMain(int argc, char** argv, const char* regen_run); + +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..690666c2400d45e33c1a5d1818b68a86a70a5be3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc @@ -0,0 +1,22 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +int main(int argc, char** argv) { + const char* regen_run = + "bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops"; + tensorflow::tf2xla::SupportedOpsMain(argc, argv, regen_run); +} diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index ecd15652fe84b0c19d2f7fc18f877236547f9be9..a9978e697b091715ce120f0d18fdddd259e08b32 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -70,10 +70,7 @@ TEST(ConvertGraphDefToXla, Sum) { xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); xla::Computation computation; - bool requires_runtime_context; - TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation, - &requires_runtime_context)); - ASSERT_FALSE(requires_runtime_context); + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. auto x_literal = xla::Literal::CreateR0(10); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 55f2f3149c6ba7bfa18608f961c8a76103a50756..f428a194328935fec1210ea96245344de859e611 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -88,8 +88,8 @@ Status ValidateConfig(const tf2xla::Config& config) { TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names)); } TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names)); - if (config.feed().empty() || config.fetch().empty()) { - return errors::InvalidArgument("feeds and fetches must be specified"); + if (config.fetch().empty()) { + return errors::InvalidArgument("fetches must be specified"); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 436039e154842443f779aba276bc571fc2ab7537..ed10d80609641b090cf78bf2e17364fe2fa89c31 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -58,24 +58,14 @@ TEST(ValidateConfig, Good) { TEST(ValidateConfig, BadEmpty) { tf2xla::Config config; - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); -} - -TEST(ValidateConfig, BadNoFeed) { - tf2xla::Config config; - tf2xla::Fetch* fetch = config.add_fetch(); - fetch->mutable_id()->set_node_name("foo"); - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); + ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); } TEST(ValidateConfig, BadNoFetch) { tf2xla::Config config; tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); + ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); } TEST(ValidateConfig, BadFeedNodeName) { diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 4f32c29954b2d809d31ef8c584b6a6c3dcdf5cef..fcb0a4e63814b4afc114bdaea312a92dd8396a2e 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -100,7 +100,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, b->SetOpMetadata(metadata); auto sharding_parse_result = ParseShardingFromDevice( - op_kernel->requested_device(), std::numeric_limits::max()); + op_kernel->def(), std::numeric_limits::max()); OP_REQUIRES_OK(context, sharding_parse_result.status()); tensorflow::gtl::optional op_sharding = sharding_parse_result.ValueOrDie(); @@ -135,98 +135,4 @@ void XlaExpression::set_constant_value(Tensor value) { constant_value_ = std::move(value); } -Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder, - xla::Shape* shape) const { - auto shape_or_status = builder->GetShape(value); - if (!shape_or_status.ok()) { - return shape_or_status.status(); - } - *shape = *shape_or_status.ValueOrDie(); - return Status::OK(); -} - -Status XlaResource::GetShape(xla::ComputationBuilder* builder, - TensorShape* shape) const { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape)); - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape)); - return Status::OK(); -} - -Status XlaResource::GetOrCreateTensorArrayGradient( - const string& source, xla::ComputationBuilder* builder, - XlaResource** gradient_out) { - VLOG(2) << "Gradient lookup for resource: " << name - << " gradient: " << source; - TF_RET_CHECK(kind == kTensorArray); - std::unique_ptr& gradient = tensor_array_gradients[source]; - if (!gradient) { - gradient.reset(new XlaResource); - gradient->kind = XlaResource::kTensorArray; - gradient->name = strings::StrCat("TensorArrayGrad: ", name); - gradient->type = type; - gradient->tensor_array_size = tensor_array_size; - - TensorShape ta_shape; - TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape)); - gradient->value = builder->Broadcast(XlaHelpers::Zero(builder, type), - ta_shape.dim_sizes()); - gradient->initial_value = gradient->value; - } - *gradient_out = gradient.get(); - return Status::OK(); -} - -Status XlaResource::PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const { - if (tensor_array_gradients.empty()) { - return GetXlaShape(builder, packed_shape); - } - TF_RET_CHECK(kind == kTensorArray); - std::vector elem_shapes(1 + tensor_array_gradients.size()); - int pos = 0; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++])); - for (const auto& gradient : tensor_array_gradients) { - TF_RETURN_IF_ERROR( - gradient.second->GetXlaShape(builder, &elem_shapes[pos++])); - } - *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); - return Status::OK(); -} - -Status XlaResource::Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const { - if (tensor_array_gradients.empty()) { - *pack = value; - } else { - TF_RET_CHECK(kind == kTensorArray); - std::vector elems; - elems.push_back(value); - for (const auto& gradient : tensor_array_gradients) { - elems.push_back(gradient.second->value); - } - *pack = builder->Tuple(elems); - } - return Status::OK(); -} - -Status XlaResource::SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder) { - if (gradient_sources.empty()) { - value = pack; - } else { - TF_RET_CHECK(kind == kTensorArray); - int pos = 0; - value = builder->GetTupleElement(pack, pos++); - for (const auto& source : gradient_sources) { - XlaResource* gradient; - TF_RETURN_IF_ERROR( - GetOrCreateTensorArrayGradient(source, builder, &gradient)); - gradient->value = builder->GetTupleElement(pack, pos++); - } - } - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 6230acd718bc330f178007b575b5119de5b3d4f4..0243ee332fbdca0fe5e28b1a7d9530df4417f807 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -66,87 +67,6 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -// Represents a resource, such as a Variable or TensorArray. -// TODO(phawkins): make this into a properly abstracted class. -struct XlaResource { - enum Kind { - kInvalid, - kVariable, - kTensorArray, - kStack, - }; - - Kind kind = kInvalid; - - // If this resource is visible externally, what was its argument number? - int arg_num = -1; - - // A descriptive name for the resource, used in error messages. - string name; - - // Current type and value of the resource. Uninitialized resources are - // represented by a default (zero) handle and type DT_INVALID. - // While the type of a resource is notionally fixed during execution, when - // a resource is first initialized we do not yet know its type, so we keep - // track of its type dynamically. - DataType type = DT_INVALID; - xla::ComputationDataHandle value; - - // Value of the resource at computation entry. Used to detect which - // variables have new values that need to be written back. - xla::ComputationDataHandle initial_value; - - // TensorArray-specific fields - - // 'tensor_array_size' stores the expected size of the TensorArray. We need - // to store this since sometimes TensorArrays must be initialized lazily since - // we do not know the element shape at construction time. - int64 tensor_array_size = -1; - - // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes - // to an XlaResource containing the gradient TensorArrays. We store a pointer - // here since there should only be one gradient TensorArray per 'source' - // string, irrespective of the number of calls to TensorArrayGrad. The map - // is ordered since values are packed into tuples by Pack() sorted by name - // order. - std::map> tensor_array_gradients; - - // Returns the shape of the resource as an xla::Shape. - Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const; - - // Returns the shape of the resource as an TensorShape. Fails if the shape is - // not representable as a TensorShape. - Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const; - - // Looks up the gradient for `source`, or creates it if it does not already - // exist. The call target must be an initialized TensorArray resource. A - // TensorArray can have multiple named gradients; see the operator - // documentation for TensorArrayGradV3 for details. - Status GetOrCreateTensorArrayGradient(const string& source, - xla::ComputationBuilder* builder, - XlaResource** gradient_out); - - // Packs a resource into a single XLA value `pack`, suitable for use as - // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without - // gradients, sets `*pack` to `value`. - // For TensorArrays with gradients, packs the value and its gradient values in - // a tuple; the gradients values are packed in order by source name. - Status Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const; - - // Returns the shape of the `pack` value computed by `Pack()`. - Status PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const; - - // Updates the resource with values from `pack`. If `gradient_sources` is - // non-empty, treats `pack` as a tuple that represents a TensorArray and - // its gradients, and unpacks and updates the gradient resources. Opposite - // of Pack(). - Status SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder); -}; - // A XlaExpression wraps an XLA computation. Each Tensor on an // XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor // matches the shape of the subcomputation in the ComputationDataHandle. Each diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index b5c17c5273bb15e20184b2fefd93880d4828105e..79da701fd244a461a60588153b601d5c1870fa89 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -28,9 +28,10 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, temps_(new void*[static_data.num_temps]), arg_names_(static_data.arg_names), result_names_(static_data.result_names), - program_shape_(static_data.program_shape) { + program_shape_(static_data.program_shape), + hlo_profile_printer_(static_data.hlo_profile_printer) { // Allocate arg and temp buffers. - if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( static_data.arg_sizes, static_data.num_args, args_, /*annotate_initialized=*/false); @@ -39,9 +40,13 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, static_data.temp_sizes, static_data.num_temps, temps_, /*annotate_initialized=*/true); - // The runtime context is always the last arg, if it is required. - if (static_data.requires_runtime_context) { - args_[static_data.num_args - 1] = &context_; + // If Hlo profiling is enabled the generated code expects an appropriately + // sized buffer to be passed in as the last argument. If Hlo profiling is + // disabled the last function argument is still present in the function + // signature, but it is ignored by the generated code and we pass in null for + // it. + if (hlo_profiling_enabled()) { + profile_counters_ = new int64[static_data.profile_counters_size](); } } @@ -50,6 +55,7 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; + delete[] profile_counters_; } namespace { diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index f49a7889222ff989144217ab10b27595f89e4311..e0ae3ed9a811bcc49ce8862037a67d293e879e57 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -16,10 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ -#include +#include #include -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/platform/types.h" @@ -27,6 +26,7 @@ limitations under the License. // never use this functionality. namespace xla { class ProgramShape; +class HloProfilePrinter; } namespace tensorflow { @@ -48,12 +48,10 @@ namespace tensorflow { class XlaCompiledCpuFunction { public: // Type of the raw function, produced by either JIT or AOT. - // - // TODO(toddw): Add support for hlo profiling, and replace std::function with - // a raw function pointer, for some codesize savings. - using RawFunction = std::function; + using RawFunction = void (*)(void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps, + int64* profile_counters); // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for @@ -71,9 +69,6 @@ class XlaCompiledCpuFunction { // The 0-based index of the result tuple, in the temp buffers. size_t result_index = 0; - // Is the final arg XlaLocalRuntimeContext? - bool requires_runtime_context = false; - // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. const char** arg_names = nullptr; @@ -81,21 +76,29 @@ class XlaCompiledCpuFunction { // [Optional] Arg and result shapes. const xla::ProgramShape* program_shape = nullptr; + + // [Optional] Profile printer. Null if profiling is disabled. + const xla::HloProfilePrinter* hlo_profile_printer = nullptr; + + // [Optional] The number of profile counters expected in the profile counter + // buffer by the generated code and hlo_profile_printer. 0 if profiling is + // disabled. + int64 profile_counters_size = 0; }; // AllocMode controls the buffer allocation mode. enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, + // Allocate all buffers - args, results, profile and temps. + ARGS_RESULTS_PROFILES_AND_TEMPS, - // Only allocate result and temp buffers. + // Only allocate result, profile and temp buffers. // Use set_arg_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, + RESULTS_PROFILES_AND_TEMPS_ONLY, }; XlaCompiledCpuFunction( const StaticData& static_data, - AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; @@ -104,21 +107,22 @@ class XlaCompiledCpuFunction { // Sets the intra-op thread pool used to run individual ops concurrently. void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { run_options_.set_intra_op_thread_pool(pool); - context_.thread_pool = pool; } // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. bool Run() { - context_.error = false; - context_.error_msg.clear(); raw_function_(temps_[result_index_], &run_options_, - const_cast(args_), temps_); - return !context_.error; + const_cast(args_), temps_, profile_counters_); + return true; } // Returns the error message from the previous failed Run call. - const string& error_msg() const { return context_.error_msg; } + // + // TODO(fschneider): For now this always returns an empty string because there + // is no support for error reporting in XLA. Remove this once all callers are + // updated. + string error_msg() const { return {}; } // ------------------------------ // Arg methods for managing input buffers. Buffers are in row-major order. @@ -141,10 +145,6 @@ class XlaCompiledCpuFunction { // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in // tensorflow/compiler/aot/runtime.h to ensure correct alignment. // - // If StaticData.requires_runtime_context==true, the final argument is an - // XlaLocalRuntimeContext, which is managed internally by this class, and - // should not be changed. - // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. void set_arg_data(size_t index, void* data) { args_[index] = data; } @@ -162,6 +162,16 @@ class XlaCompiledCpuFunction { return static_cast(temps_[result_index_]); } + // Profile counters for this XLA computation. + // + // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in + // this case) these counters are non-null and are automatically populated by + // `Run`. The counters can then be pretty-printed using + // `hlo_profile_printer()`. + // + // When Hlo profiling is disabled, this accessor returns null. + const int64* profile_counters() const { return profile_counters_; } + // Returns the buffer for the positional result at the given `index`. void* result_data(size_t index) { return results()[index]; } const void* result_data(size_t index) const { return results()[index]; } @@ -195,6 +205,12 @@ class XlaCompiledCpuFunction { // program shape isn't available. const xla::ProgramShape* ProgramShape() const { return program_shape_; } + bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } + const xla::HloProfilePrinter& hlo_profile_printer() const { + assert(hlo_profiling_enabled()); + return *hlo_profile_printer_; + } + private: const RawFunction raw_function_; const size_t result_index_; @@ -208,14 +224,17 @@ class XlaCompiledCpuFunction { void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; + // Backing memory for profiling counters. + int64* profile_counters_ = nullptr; + // Options and context passed to the compiled function. xla::ExecutableRunOptions run_options_; - tensorflow::XlaLocalRuntimeContext context_; // Optional metadata. const char** arg_names_ = nullptr; const char** result_names_ = nullptr; const xla::ProgramShape* program_shape_ = nullptr; + const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 48cebdf74c71f974bf075e0255626ec57eb9a149..69b265436bb19bbbdd9deb872f4097d4bac7ea52 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -268,7 +268,8 @@ Status BuildArguments(const Graph& graph, XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, - std::vector* input_shapes) { + std::vector* input_shapes, + bool is_entry_computation) { arg_expressions->resize(args.size()); *arg_cores = std::vector(args.size(), -1); @@ -292,7 +293,7 @@ Status BuildArguments(const Graph& graph, TF_RETURN_IF_ERROR( context->CreateResource(arg.resource_kind, i, arg.name, arg.type, xla::ComputationDataHandle(), &resource)); - resource->tensor_array_size = arg.tensor_array_size; + resource->set_tensor_array_size(arg.tensor_array_size); arg_expression.set_resource(resource); if (arg.initialized) { resources.push_back(i); @@ -316,15 +317,22 @@ Status BuildArguments(const Graph& graph, return Status::OK(); } - input_shapes->resize(parameters.size()); + std::vector arg_shapes; + arg_shapes.reserve(parameters.size()); input_mapping->resize(parameters.size()); for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; // Computes the shapes of non-constant arguments. - (*input_shapes)[i] = arg.shape; + arg_shapes.push_back(arg.shape); (*input_mapping)[i] = parameters[i]; } + if (use_tuple_arg) { + input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes)); + } else { + *input_shapes = arg_shapes; + } + // Use the _Arg nodes in the graph to resolve core assignments. for (const Node* n : graph.nodes()) { if (StringPiece(n->type_string()) != "_Arg") continue; @@ -348,14 +356,28 @@ Status BuildArguments(const Graph& graph, // Build parameter handles for non-constant arguments. std::vector arg_handles(parameters.size()); if (use_tuple_arg) { - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes); - xla::ComputationDataHandle tuple = - builder->Parameter(0, tuple_shape, "arg_tuple"); + xla::ComputationDataHandle tuple; + if (is_entry_computation) { + xla::OpSharding tuple_sharding; + tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); + for (int64 parameter : parameters) { + const int core = (*arg_cores)[parameter]; + const int root_device = 0; + *tuple_sharding.add_tuple_shardings() = + core == -1 ? xla::sharding_builder::AssignDevice(root_device) + : xla::sharding_builder::AssignDevice(core); + } + xla::ScopedShardingAssignment assign_tuple_sharding(builder, + tuple_sharding); + tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + } else { + tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + } for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const int core = (*arg_cores)[parameters[i]]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() - : xla::ShardingBuilder::AssignDevice(core)); + : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); } } else { @@ -363,7 +385,7 @@ Status BuildArguments(const Graph& graph, const int core = (*arg_cores)[parameters[i]]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() - : xla::ShardingBuilder::AssignDevice(core)); + : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); } @@ -374,21 +396,18 @@ Status BuildArguments(const Graph& graph, for (std::vector::size_type i = 0; i < parameters.size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; VLOG(2) << " XLA arg " << i - << " shape: " << xla::ShapeUtil::HumanString((*input_shapes)[i]) + << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) << " name: " << arg.name << " TF arg " << parameters[i]; XlaExpression& arg_expression = (*arg_expressions)[parameters[i]]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); XlaResource* resource = arg_expression.resource(); - TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients, - arg_handles[i], builder)); + TF_RETURN_IF_ERROR( + resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i], + /*reset_initial_values=*/true, builder)); VLOG(2) << " resource: num_gradients: " << arg.tensor_array_gradients.size(); - resource->initial_value = resource->value; - for (const auto& gradient : resource->tensor_array_gradients) { - gradient.second->initial_value = gradient.second->value; - } break; } case XlaCompiler::Argument::kParameter: @@ -439,43 +458,43 @@ Status BuildComputation( std::vector arg_resources; arg_resources.reserve(resources.size()); for (const auto& resource : resources) { - if (resource->arg_num >= 0) { + if (resource->arg_num() >= 0) { arg_resources.push_back(resource.get()); } } std::sort(arg_resources.begin(), arg_resources.end(), [](const XlaResource* a, const XlaResource* b) { - return a->arg_num < b->arg_num; + return a->arg_num() < b->arg_num(); }); for (const XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num]; - const int core = arg_cores[resource->arg_num]; - DCHECK_LT(resource->arg_num, arg_cores.size()); + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + const int core = arg_cores[resource->arg_num()]; + DCHECK_LT(resource->arg_num(), arg_cores.size()); bool modified = - resource->value.handle() != resource->initial_value.handle(); + resource->value().handle() != resource->initial_value().handle(); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients) { - modified = - modified || - grad.second->value.handle() != grad.second->initial_value.handle() || - arg.tensor_array_gradients.count(grad.first) == 0; + for (const auto& grad : resource->tensor_array_gradients()) { + modified = modified || + grad.second->value().handle() != + grad.second->initial_value().handle() || + arg.tensor_array_gradients.count(grad.first) == 0; } if (return_updated_values_for_all_resources || modified) { resource_updates->emplace_back(); XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num; - update.type = resource->type; + update.input_index = resource->arg_num(); + update.type = resource->type(); update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients) { + for (const auto& grad : resource->tensor_array_gradients()) { update.tensor_array_gradients_accessed.insert(grad.first); } // Request that the value be returned on a specific core. xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() - : xla::ShardingBuilder::AssignDevice(core)); + : xla::sharding_builder::AssignDevice(core)); xla::ComputationDataHandle handle; TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); @@ -502,18 +521,6 @@ Status BuildComputation( return Status::OK(); } -void AssignMajorToMinorLayout(xla::Shape* shape) { - if (xla::ShapeUtil::IsTuple(*shape)) { - for (xla::Shape& elem_shape : *shape->mutable_tuple_shapes()) { - AssignMajorToMinorLayout(&elem_shape); - } - } else { - auto& minor_to_major = *shape->mutable_layout()->mutable_minor_to_major(); - minor_to_major.Resize(xla::ShapeUtil::Rank(*shape), 0); - std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); - } -} - } // namespace Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, @@ -543,13 +550,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options.resolve_compile_time_constants); core::ScopedUnref context_unref(context); - result->tuple_arg = options.use_tuple_arg; - std::vector arg_expressions; std::vector arg_cores; - TF_RETURN_IF_ERROR(BuildArguments( - *graph, args, options.use_tuple_arg, &builder, context, &arg_cores, - &arg_expressions, &result->input_mapping, &result->xla_input_shapes)); + TF_RETURN_IF_ERROR( + BuildArguments(*graph, args, options.use_tuple_arg, &builder, context, + &arg_cores, &arg_expressions, &result->input_mapping, + &result->xla_input_shapes, options.is_entry_computation)); context->set_args(std::move(arg_expressions)); TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, @@ -564,11 +570,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->resource_updates)); - result->requires_runtime_context = context->has_context_parameter(); - - // Tuple arguments and runtime context parameters are incompatible. - TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); - VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; result->outputs.resize(context->retvals().size()); @@ -596,7 +597,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, << xla::ShapeUtil::HumanString(result->xla_output_shape); // Tensorflow expects a major-to-minor order of results. - AssignMajorToMinorLayout(&result->xla_output_shape); + xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); // Converts the output shapes to TensorShapes. int computation_output = 0; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index ac7d4cfb127d1de8c92f3a855191c45af77888ad..6a46e54f61cb4dbb2a2c1916696655a4e3d85fff 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -54,8 +54,6 @@ namespace tensorflow { // +---------------------+-----------------------------------------+ // Within each block, the arguments are arranged by the _Arg index from which // they were derived. -// If `Options::requires_runtime_context` is true, then an additional runtime -// context argument is passed as a final argument. // // The run-time outputs of the XLA computation are arranged in the following // order: @@ -154,6 +152,10 @@ class XlaCompiler { // as Tensors at compile-time, rather than as run-time outputs of the // computation. bool resolve_compile_time_constants = true; + + // True when compiling the entry computation, false for subcomputations + // (while, call, etc.) + bool is_entry_computation = true; }; struct OutputDescription { @@ -191,16 +193,9 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Does the computation require the local runtime context to be passed as - // the last argument? - bool requires_runtime_context = false; - // Input shapes of the computation. std::vector xla_input_shapes; - // Should the arguments be packed into a single tuple? - bool tuple_arg; - // Output shape in XLA format. The output shape is always a tuple. xla::Shape xla_output_shape; @@ -232,8 +227,7 @@ class XlaCompiler { int graph_def_version = TF_GRAPH_DEF_VERSION; // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall() - // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed - // to the computation. + // for CPU. bool allow_cpu_custom_calls = false; // If not nullptr, populate_resource_manager is called with the diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 93aae8485d157cd4afbf804d695d5c0ab8d7946c..7ebe4b75bc1e33e506624314b11163e36a2477de 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -227,6 +227,42 @@ TEST_F(XlaCompilerTest, Simple) { xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } +TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { + // Builds a graph that adds reshapes a tensor, but with the shape not + // statically known. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + auto c = ops::Reshape(scope.WithOpName("C"), a, b); + auto d = ops::_Retval(scope.WithOpName("D"), c, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + Status status = + compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape", + std::move(graph), args, &result); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE( + StringPiece(status.error_message()).contains("depends on a parameter")) + << status.error_message(); + EXPECT_TRUE( + StringPiece(status.error_message()).contains("[[Node: C = Reshape")) + << status.error_message(); +} + // Tests handling of compile-time constant outputs. TEST_F(XlaCompilerTest, ConstantOutputs) { // Builds a graph with one compile-time constant output and one data-dependent diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 651bafd6c5d946adfedd63ebbe93e4ea016f0b37..e8d17e2e0a1ba01f16d4bbbd2895b112f4dd1989 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -70,24 +70,6 @@ XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants) {} -const xla::ComputationDataHandle& -XlaContext::GetOrCreateRuntimeContextParameter() { - CHECK(allow_cpu_custom_calls_); - if (has_context_parameter_) return context_parameter_; - has_context_parameter_ = true; - - // Allocate the next available parameter for the context parameter. - int num_parameters = 0; - for (const XlaExpression& arg : args_) { - if (!arg.has_constant_value()) { - ++num_parameters; - } - } - context_parameter_ = builder_->Parameter( - num_parameters, xla::ShapeUtil::MakeOpaqueShape(), "tf_context"); - return context_parameter_; -} - string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value @@ -125,14 +107,9 @@ Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num, string name, DataType type, const xla::ComputationDataHandle& handle, XlaResource** resource) { - resources_.emplace_back(new XlaResource); + resources_.emplace_back( + new XlaResource(kind, arg_num, std::move(name), type, handle)); *resource = resources_.back().get(); - XlaResource& r = **resource; - r.kind = kind; - r.arg_num = arg_num; - r.name = std::move(name); - r.type = type; - r.initial_value = r.value = handle; return Status::OK(); } @@ -178,6 +155,20 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } +const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { + return LookupOrCreate(type, &mul_func_, [this, type] { + const string type_string = DataTypeString(type); + VLOG(1) << "Building Mul() for " << type_string; + xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">"); + xla::PrimitiveType xla_type; + TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); + auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + b.Mul(x, y); + return b.Build().ConsumeValueOrDie(); + }); +} + const xla::Computation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, const std::function& create) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index de8aafa3628e6eebdabbc508cd95a2ac86e3472f..1a7dafe8cdb56cc9b8fcd3ba6e262c21c2a07d90 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -56,15 +56,10 @@ class XlaContext : public ResourceBase { xla::ComputationBuilder* builder(); bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool has_context_parameter() const { return has_context_parameter_; } const std::vector& args() const { return args_; } void set_args(std::vector args); - // Get the runtime context parameter, adding one if it does not already exist. - // Dies if not compiling a local executable. - const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter(); - const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value @@ -102,6 +97,11 @@ class XlaContext : public ResourceBase { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Get an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -116,16 +116,9 @@ class XlaContext : public ResourceBase { const bool allow_cpu_custom_calls_; // If true, constant return values are returned as Tensors instead of - // run-time computation outptus. + // run-time computation outputs. const bool resolve_compile_time_constants_; - // When 'has_context_parameter_' is true, this is the computation handle - // for an additional final parameter to the computation, through which will be - // passed a XlaLocalRuntimeContext* at runtime. Created on demand by - // GetOrCreateRuntimeContextParameter(). - bool has_context_parameter_ = false; - xla::ComputationDataHandle context_parameter_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; @@ -155,6 +148,9 @@ class XlaContext : public ResourceBase { // Cached computation to compute Sum of two elements, specialized by type. ComputationMap add_func_; + // Cached computation to compute Mul of two elements, specialized by type. + ComputationMap mul_func_; + // Cached computation to compute Sigmoid of an element, specialized by type. ComputationMap sigmoid_func_; diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index d504613d232c779e47a506657d2825d052e726dc..8ca757e72355d890c13b8b448d35c327d3986696 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -21,8 +21,6 @@ namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to // slow code. - // TODO(b/34969189) The implementation of TruncatedNormal generates illegal - // code on GPU. if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { return false; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9c3e15d2fa4c84af94d137f2e03107bcc980f4cd..77e24162676045b88dc8b62d2c6a4ecc1e738e96 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines helper routines for Tla JIT compilation. +// This file defines helper routines for XLA compilation. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" @@ -121,6 +121,8 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, DataType data_type) { switch (data_type) { + case DT_BFLOAT16: + return b->ConstantR0(bfloat16::epsilon()); case DT_FLOAT: return b->ConstantR0(std::numeric_limits::epsilon()); case DT_DOUBLE: @@ -138,40 +140,44 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { case xla::U8: - literal = *xla::Literal::CreateR0(value); + literal = std::move(*xla::Literal::CreateR0(value)); break; case xla::U32: - literal = *xla::Literal::CreateR0(value); + literal = std::move(*xla::Literal::CreateR0(value)); break; case xla::U64: - literal = *xla::Literal::CreateR0(value); + literal = std::move(*xla::Literal::CreateR0(value)); break; case xla::S8: - literal = *xla::Literal::CreateR0(value); + literal = std::move(*xla::Literal::CreateR0(value)); break; case xla::S32: - literal = *xla::Literal::CreateR0(value); + literal = std::move(*xla::Literal::CreateR0(value)); break; case xla::S64: - literal = *xla::Literal::CreateR0(value); + literal = std::move(*xla::Literal::CreateR0(value)); break; case xla::F32: - literal = *xla::Literal::CreateR0(value); + literal = std::move(*xla::Literal::CreateR0(value)); break; case xla::F64: - literal = *xla::Literal::CreateR0(value); + literal = std::move(*xla::Literal::CreateR0(value)); break; case xla::C64: - literal = *xla::Literal::CreateR0(value); + literal = std::move(*xla::Literal::CreateR0(value)); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; case xla::S16: case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; + case xla::BF16: + literal = std::move( + *xla::Literal::CreateR0(static_cast(value))); + break; case xla::F16: - literal = - *xla::Literal::CreateR0(static_cast(value)); + literal = std::move( + *xla::Literal::CreateR0(static_cast(value))); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; @@ -207,8 +213,8 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, "elements."); } - *output = input; - output->mutable_shape()->Swap(&shape); + *output = input.Clone(); + output->mutable_shape_do_not_use()->Swap(&shape); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 1dd454ea8d57e21526e5bcde0c8efc5514983b93..584417bc72c8f6645c05912e857b031cfb394e54 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -37,27 +37,14 @@ namespace { // Returns a vector of positional argument buffer sizes. xla::StatusOr> ComputeArgSizes( - const xla::ProgramShape& program_shape, bool requires_runtime_context) { + const xla::ProgramShape& program_shape) { std::vector arg_sizes; const size_t num_args = program_shape.parameters_size(); arg_sizes.reserve(num_args); for (int i = 0; i < num_args; ++i) { const xla::Shape& arg_shape = program_shape.parameters(i); - if (i == num_args - 1 && requires_runtime_context) { - // If the compiled function needs an XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. - const xla::PrimitiveType type = arg_shape.element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(program_shape)); - } - arg_sizes.push_back(-1); - } else { - constexpr size_t kPointerSize = sizeof(void*); - arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); - } + constexpr size_t kPointerSize = sizeof(void*); + arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); } return std::move(arg_sizes); } @@ -90,21 +77,6 @@ xla::StatusOr ComputeResultIndex( return result_slice.index(); } -// Adapt ComputeFunctionType, which includes a final profile_counters arg, to -// RawFunction, which doesn't include that final arg. -// -// TODO(toddw): Change RawFunction and AOT to also pass the final -// profile_counters arg, and remove this adapter. -XlaCompiledCpuFunction::RawFunction RawFunctionAdapter( - xla::cpu::CpuExecutable::ComputeFunctionType compute_function) { - return [compute_function](void* result, - const xla::ExecutableRunOptions* run_options, - const void** args, void** temps) { - return compute_function(result, run_options, args, temps, - /*profile_counters=*/nullptr); - }; -} - // Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold // the actual strings in nonempty_names, and hold arrays of pointers in // name_ptrs, terminated by a nullptr entry. @@ -144,9 +116,8 @@ XlaJitCompiledCpuFunction::Compile( TF_ASSIGN_OR_RETURN(xla::LocalClient * client, xla::ClientLibrary::GetOrCreateLocalClient()); xla::Computation computation; - bool requires_runtime_context; - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla( - graph_def, config, client, &computation, &requires_runtime_context)); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client, + &computation)); // Get and verify the program shape. TF_ASSIGN_OR_RETURN(std::unique_ptr program_shape, @@ -177,14 +148,13 @@ XlaJitCompiledCpuFunction::Compile( const xla::cpu::CpuExecutable* cpu_executable = static_cast(executable->executable()); XlaCompiledCpuFunction::RawFunction raw_function = - RawFunctionAdapter(cpu_executable->compute_function()); + cpu_executable->compute_function(); const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); // Compute buffer sizes and the result index, needed to run the raw function. - TF_ASSIGN_OR_RETURN( - std::vector arg_sizes, - ComputeArgSizes(*program_shape, requires_runtime_context)); + TF_ASSIGN_OR_RETURN(std::vector arg_sizes, + ComputeArgSizes(*program_shape)); TF_ASSIGN_OR_RETURN(std::vector temp_sizes, ComputeTempSizes(buffer_assignment)); TF_ASSIGN_OR_RETURN(size_t result_index, @@ -203,7 +173,6 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.temp_sizes = jit->temp_sizes_.data(); jit->static_data_.num_temps = jit->temp_sizes_.size(); jit->static_data_.result_index = result_index; - jit->static_data_.requires_runtime_context = requires_runtime_context; // Optional metadata is collected and set below. CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, @@ -211,6 +180,14 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.arg_names = jit->arg_names_.data(); jit->static_data_.result_names = jit->result_names_.data(); jit->static_data_.program_shape = jit->program_shape_.get(); + + if (cpu_executable->hlo_profiling_enabled()) { + jit->static_data_.hlo_profile_printer = + &cpu_executable->hlo_profile_printer(); + jit->static_data_.profile_counters_size = + cpu_executable->hlo_profile_printer().profile_counters_size(); + } + return std::move(jit_unique_ptr); } diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h deleted file mode 100644 index dca420d6ee3fec45f88ac3b450ab0cb4fb83d38a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ - -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -// Forward-declare the ThreadPoolDevice so that it can be ignored unless it's -// actually used. E.g. some ahead-of-time compiled computations don't need a -// thread pool. -namespace Eigen { -struct ThreadPoolDevice; -} - -namespace tensorflow { - -// An instance of this class is passed to each call from tensorflow into a -// compiled XLA computation. See xla_launch_ops.cc. -struct XlaLocalRuntimeContext { - public: - XlaLocalRuntimeContext() {} - - // Kernels implemented using custom call ops set this if they encounter an - // error. The error is checked after the entire XLA computation is - // complete. - // - // error+error_msg are used instead of Status to reduce the binary size - // overhead for ahead-of-time compiled binaries. - bool error = false; - string error_msg; - - // Kernels that need a thread pool can get it from here. - const Eigen::ThreadPoolDevice* thread_pool = nullptr; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalRuntimeContext); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2b4cc9ba2d62b0e559e1456e6bfe6ab1e094e1df..c0c4251eabcd06d7c84ae76f349d657fa9f6d641 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -118,13 +118,36 @@ Status XlaOpKernelContext::ConstantInputReshaped( std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + xla::StatusOr is_constant = builder()->IsConstant(handle); + if (!is_constant.ok()) { + Status status = is_constant.status(); + errors::AppendToMessage(&status, "while evaluating input ", index, " of ", + context_->op_kernel().type_string(), + " operator as a compile-time constant."); + return status; + } + + if (!is_constant.ValueOrDie()) { + return errors::InvalidArgument( + "Input ", index, " to ", context_->op_kernel().type_string(), + " operator must be a compile-time constant.\n" + "\n" + "XLA compilation requires that operator arguments that represent " + "shapes or dimensions be evaluated to concrete values at compile time. " + "This error means that a shape or dimension argument could not be " + "evaluated at compile time, usually because the value of the argument " + "depends on a parameter to the computation, on a variable, or on a " + "stateful operation such as a random number generator."); + } + // Ask the XLA compiler to evaluate the data handle to a literal. xla::StatusOr> computed = builder()->ComputeConstant(handle, &layout); if (!computed.ok()) { - return errors::InvalidArgument( - "Error evaluating ", context_->op_kernel().name(), " input ", index, - ": ", computed.status().error_message()); + return errors::Internal("Error evaluating ", context_->op_kernel().name(), + " input ", index, + "as a compile-time constant.\nError: ", + computed.status().error_message()); } *constant_literal = std::move(*computed.ValueOrDie()); @@ -206,15 +229,15 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); switch (literal.shape().element_type()) { - case xla::S32: - out->Clear(); - *out->mutable_shape() = literal.shape(); - out->mutable_shape()->set_element_type(xla::S64); - for (int32 x : literal.s32s()) { - out->add_s64s(x); + case xla::S32: { + *out = xla::Literal( + xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64)); + auto src_data = literal.data(); + for (int64 i = 0; i < src_data.size(); ++i) { + out->data()[i] = src_data[i]; } return Status::OK(); - + } case xla::S64: *out = std::move(literal); return Status::OK(); @@ -268,12 +291,12 @@ Status XlaOpKernelContext::ReadVariableInput( const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); - TF_RET_CHECK(variable->kind == XlaResource::kVariable); - if (variable->value.handle() == 0) { + TF_RET_CHECK(variable->kind() == XlaResource::kVariable); + if (!variable->initialized()) { return errors::InvalidArgument("Read of uninitialized variable ", - variable->name); + variable->name()); } - *value = variable->value; + *value = variable->value(); return Status::OK(); } @@ -283,13 +306,13 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); - TF_RET_CHECK(variable->kind == XlaResource::kVariable); - if (variable->value.handle() == 0) { + TF_RET_CHECK(variable->kind() == XlaResource::kVariable); + if (!variable->initialized()) { return errors::InvalidArgument("Read of uninitialized variable ", - variable->name); + variable->name()); } - *type = variable->type; - auto shape_or_status = builder()->GetShape(variable->value); + *type = variable->type(); + auto shape_or_status = builder()->GetShape(variable->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); } @@ -381,16 +404,8 @@ Status XlaOpKernelContext::AssignVariable( CastExpressionFromTensor(context_->input(input_index)); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); - TF_RET_CHECK(variable->kind == XlaResource::kVariable); - if (!((variable->type == DT_INVALID && type != DT_INVALID) || - (variable->type == type))) { - return errors::InvalidArgument( - "Types of variables cannot change after initialization: old type was ", - DataTypeString(variable->type), ", new type is ", DataTypeString(type)); - } - variable->type = type; - variable->value = handle; - return Status::OK(); + TF_RET_CHECK(variable->kind() == XlaResource::kVariable); + return variable->SetValue(type, handle); } XlaCompiler* XlaOpKernelContext::compiler() const { @@ -417,6 +432,11 @@ const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( return XlaContext::Get(context_).GetOrCreateAdd(type); } +const xla::Computation* XlaOpKernelContext::GetOrCreateMul( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateMul(type); +} + XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 76bcf594e6a0601763844847583c18ee26d8adf3..f1ae81a5aa9d507a3e0dd577568377385b1844e6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -178,7 +178,7 @@ class XlaOpKernelContext { // If this kernel invocation is within a function execution, // call_frame() returns the call frame for the function call. - FunctionCallFrame* call_frame() const { return context_->call_frame(); } + CallFrameInterface* call_frame() const { return context_->call_frame(); } FunctionLibraryRuntime* function_library() const { return context_->function_library(); @@ -210,6 +210,11 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Gets an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + private: OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 02318cf7fa1d4edc12507f6b4d66a8e897cbe100..0dde6a986c61bdd5b0b2e6d7a16b29ab95be98ab 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -82,6 +83,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; return false; } } + if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible compile time constant inputs."; + return false; + } return true; } @@ -155,7 +161,14 @@ void XlaOpRegistry::RegisterCompilationKernels() { const string& op_name = op.first; const std::unique_ptr& op_registration = op.second; const OpDef* op_def; - TF_CHECK_OK(op_registry->LookUpOpDef(op_name, &op_def)); + Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); + if (!lookup_status.ok()) { + LOG(ERROR) << lookup_status.error_message(); + XLA_LOG_LINES( + ERROR, "Ops registered: \n" + + dynamic_cast(op_registry)->DebugString(true)); + } + TF_CHECK_OK(lookup_status); std::unordered_set type_attrs; for (const OpDef::AttrDef& attr_def : op_def->attr()) { @@ -187,22 +200,39 @@ void XlaOpRegistry::RegisterCompilationKernels() { // Constrain each type attribute to the intersection of: // a) the types supported by the backend, and - // b) the attribute's type constraints. - // TODO(phawkins): it may be necessary to also take the intersection with - // the set of types supported by the OpDef. + // b) the types allowed by the OpDef, and + // c) the type constraints. for (const string& type_attr : type_attrs) { KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); attr_constraint->set_name(type_attr); auto* allowed_values = attr_constraint->mutable_allowed_values()->mutable_list(); - auto it = op_registration->type_constraints.find(type_attr); + const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); + const auto* op_def_allowed_types = + op_def_attr.has_allowed_values() + ? &op_def_attr.allowed_values().list().type() + : nullptr; + auto constraint_it = op_registration->type_constraints.find(type_attr); + const std::set* type_constraints = + constraint_it != op_registration->type_constraints.end() + ? &constraint_it->second + : nullptr; for (DataType dtype : backend.second.supported_types) { - if (it == op_registration->type_constraints.end() || - (it != op_registration->type_constraints.end() && - it->second.find(dtype) != it->second.end())) { - allowed_values->add_type(dtype); + // Filter out types that aren't allowed by the OpDef. + if (op_def_allowed_types != nullptr && + std::find(op_def_allowed_types->begin(), + op_def_allowed_types->end(), + dtype) == op_def_allowed_types->end()) { + continue; + } + // Filter out types based on the type constraints. + if (type_constraints != nullptr && + type_constraints->find(dtype) == type_constraints->end()) { + continue; } + // Passed all the filters, this type is allowed. + allowed_values->add_type(dtype); } if (op_registration->allow_resource_types) { allowed_values->add_type(DT_RESOURCE); @@ -245,6 +275,33 @@ std::vector XlaOpRegistry::DeviceKernels( return kernels; } +/* static */ const std::unordered_set* +XlaOpRegistry::CompileTimeConstantInputs(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end()) { + return nullptr; + } + return &it->second->compile_time_constant_inputs; +} + +std::vector XlaOpRegistry::BackendNames() { + std::vector names; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + for (const auto& backend_pair : registry.backends_) { + names.push_back(backend_pair.first); + } + return names; +} + +bool XlaOpRegistry::IsBackendRegistered(const string& name) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + return registry.backends_.find(name) != registry.backends_.end(); +} + XlaOpRegistry& XlaOpRegistry::Instance() { static XlaOpRegistry* r = new XlaOpRegistry; return *r; @@ -303,6 +360,12 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( + StringPiece input_name) { + registration_->compile_time_constant_inputs.insert(input_name.ToString()); + return *this; +} + std::unique_ptr XlaOpRegistrationBuilder::Build( XlaOpRegistry::Factory factory) { registration_->factory = factory; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 6aee8c91cc01b4382ef867fa8e438eede008ac73..ff7453194af3a85bded86a5ce298f8779422dccb 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -45,11 +45,11 @@ extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT" extern const char* const DEVICE_XLA_CPU; extern const char* const DEVICE_XLA_GPU; -constexpr std::array kFloatTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE}}; -constexpr std::array kNumericTypes = { +constexpr std::array kFloatTypes = { + {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; +constexpr std::array kNumericTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64}}; + DT_COMPLEX64, DT_BFLOAT16}}; constexpr std::array kCpuAllTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, @@ -97,6 +97,12 @@ class XlaOpRegistry { gtl::ArraySlice supported_types, BackendOpFilter op_filter); + // Returns the names of the registered backends. + static std::vector BackendNames(); + + // Returns true iff a backend with the given name is registered. + static bool IsBackendRegistered(const string& name); + // Registers `device_name` for XLA compilation, using information from // `registration`. static void RegisterCompilationDevice(const string& device_name, @@ -116,12 +122,17 @@ class XlaOpRegistry { static void RegisterCompilationKernels(); // Returns KernelDefs for compilation ops registered on - // 'compilation_device_name'. - // Does not include kernels registered as CompilationOnly. + // 'compilation_device_name'. Does not include kernels registered as + // CompilationOnly, iff include_compilation_only_kernels=false. static std::vector DeviceKernels( const string& compilation_device_name, bool include_compilation_only_kernels); + // Returns the set of compile-time constant inputs to 'op'. Returns nullptr + // if the op is not registered. + static const std::unordered_set* CompileTimeConstantInputs( + const string& op); + private: friend class XlaBackendRegistrar; friend class XlaOpRegistrar; @@ -175,6 +186,9 @@ class XlaOpRegistry { bool has_device_whitelist = false; std::unordered_set device_whitelist; + // Names of arguments that must be compile-time constants. + std::unordered_set compile_time_constant_inputs; + // Factory used to build OpKernels that perform symbolic execution. Factory factory; }; @@ -236,6 +250,9 @@ class XlaOpRegistrationBuilder { // Allow DT_RESOURCE types for type parameters. XlaOpRegistrationBuilder& AllowResourceTypes(); + // Mark 'input_name' as an argument whose value must be known at compile-time. + XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name); + std::unique_ptr Build( XlaOpRegistry::Factory factory); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc new file mode 100644 index 0000000000000000000000000000000000000000..9abac8bdaa77c99a57b2f8ac66fe6ed06fbcd102 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -0,0 +1,157 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_resource.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" + +namespace tensorflow { + +XlaResource::XlaResource(Kind kind, int arg_num, string name, + DataType initial_type, + const xla::ComputationDataHandle& initial_value) + : kind_(kind), + arg_num_(arg_num), + name_(std::move(name)), + type_(initial_type), + value_(initial_value), + initial_value_(initial_value) { + CHECK(kind_ != kInvalid); +} + +Status XlaResource::SetValue(DataType type, + const xla::ComputationDataHandle& value) { + if (type_ == DT_INVALID && type == DT_INVALID) { + return errors::InvalidArgument("Attempted to initialized resource ", name_, + " to an invalid type"); + } + if (type_ != DT_INVALID && type_ != type) { + return errors::InvalidArgument("Type of resource ", name_, + " cannot be changed after initialization: " + "old type was ", + DataTypeString(type_), ", new type is ", + DataTypeString(type)); + } + type_ = type; + value_ = value; + return Status::OK(); +} + +Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder, + xla::Shape* shape) const { + auto shape_or_status = builder->GetShape(value_); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + *shape = *shape_or_status.ValueOrDie(); + return Status::OK(); +} + +Status XlaResource::GetShape(xla::ComputationBuilder* builder, + TensorShape* shape) const { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape)); + return Status::OK(); +} + +Status XlaResource::GetOrCreateTensorArrayGradient( + const string& source, xla::ComputationBuilder* builder, + XlaResource** gradient_out) { + VLOG(2) << "Gradient lookup for resource: " << name_ + << " gradient: " << source; + TF_RET_CHECK(kind_ == kTensorArray); + std::unique_ptr& gradient = tensor_array_gradients_[source]; + if (!gradient) { + TensorShape ta_shape; + TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape)); + xla::ComputationDataHandle gradient_value = builder->Broadcast( + XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); + gradient.reset( + new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, + /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + type_, gradient_value)); + gradient->tensor_array_size_ = tensor_array_size_; + } + *gradient_out = gradient.get(); + return Status::OK(); +} + +Status XlaResource::PackedShape(xla::ComputationBuilder* builder, + xla::Shape* packed_shape) const { + if (tensor_array_gradients_.empty()) { + return GetXlaShape(builder, packed_shape); + } + TF_RET_CHECK(kind_ == kTensorArray); + std::vector elem_shapes(1 + tensor_array_gradients_.size()); + int pos = 0; + TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++])); + for (const auto& gradient : tensor_array_gradients_) { + TF_RETURN_IF_ERROR( + gradient.second->GetXlaShape(builder, &elem_shapes[pos++])); + } + *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); + return Status::OK(); +} + +Status XlaResource::Pack(xla::ComputationDataHandle* pack, + xla::ComputationBuilder* builder) const { + if (tensor_array_gradients_.empty()) { + *pack = value_; + } else { + TF_RET_CHECK(kind_ == kTensorArray); + std::vector elems; + elems.push_back(value_); + for (const auto& gradient : tensor_array_gradients_) { + elems.push_back(gradient.second->value_); + } + *pack = builder->Tuple(elems); + } + return Status::OK(); +} + +Status XlaResource::SetFromPack(const std::set& gradient_sources, + const xla::ComputationDataHandle& pack, + bool reset_initial_values, + xla::ComputationBuilder* builder) { + if (gradient_sources.empty()) { + value_ = pack; + } else { + TF_RET_CHECK(kind_ == kTensorArray); + int pos = 0; + value_ = builder->GetTupleElement(pack, pos++); + for (const auto& source : gradient_sources) { + XlaResource* gradient; + TF_RETURN_IF_ERROR( + GetOrCreateTensorArrayGradient(source, builder, &gradient)); + gradient->value_ = builder->GetTupleElement(pack, pos++); + if (reset_initial_values) { + gradient->initial_value_ = gradient->value_; + } + } + } + if (reset_initial_values) { + initial_value_ = value_; + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h new file mode 100644 index 0000000000000000000000000000000000000000..6b46089e4f5e10c195bb59f78c33305c2fa3f84d --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -0,0 +1,149 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Represents a resource, such as a Variable or TensorArray. +class XlaResource { + public: + enum Kind { + kInvalid, + kVariable, + kTensorArray, + kStack, + }; + + XlaResource(Kind kind, int arg_num, string name, DataType initial_type, + const xla::ComputationDataHandle& initial_value); + + XlaResource(const XlaResource&) = delete; + XlaResource(XlaResource&&) = delete; + XlaResource& operator=(const XlaResource&) = delete; + XlaResource& operator=(XlaResource&&) = delete; + + Kind kind() const { return kind_; } + + // If this resource is visible externally to the computation, what was its + // argument number? + // < 0 means "not visible externally". + int arg_num() const { return arg_num_; } + + // A descriptive name for the resource, used in error messages. + const string& name() const { return name_; } + + // Current type and value of the resource. Uninitialized resources are + // represented by a default (zero) handle and type DT_INVALID. + // While the type of a resource is notionally fixed during execution, when + // a resource is first initialized we do not yet know its type, so we keep + // track of its type dynamically. + DataType type() const { return type_; } + const xla::ComputationDataHandle& value() const { return value_; } + + // Value of the resource at computation entry. Used to detect which + // variables have new values that need to be written back. + const xla::ComputationDataHandle& initial_value() const { + return initial_value_; + } + + bool initialized() const { return value_.handle() > 0; } + + // Sets the current type/value of the resource. + Status SetValue(DataType type, const xla::ComputationDataHandle& value); + + // Returns the shape of the resource as an xla::Shape. + Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const; + + // Returns the shape of the resource as an TensorShape. Fails if the shape is + // not representable as a TensorShape. + Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const; + + // Looks up the gradient for `source`, or creates it if it does not already + // exist. The call target must be an initialized TensorArray resource. A + // TensorArray can have multiple named gradients; see the operator + // documentation for TensorArrayGradV3 for details. + Status GetOrCreateTensorArrayGradient(const string& source, + xla::ComputationBuilder* builder, + XlaResource** gradient_out); + + // Packs a resource into a single XLA value `pack`, suitable for use as + // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without + // gradients, sets `*pack` to `value`. + // For TensorArrays with gradients, packs the value and its gradient values in + // a tuple; the gradients values are packed in order by source name. + Status Pack(xla::ComputationDataHandle* pack, + xla::ComputationBuilder* builder) const; + + // Returns the shape of the `pack` value computed by `Pack()`. + Status PackedShape(xla::ComputationBuilder* builder, + xla::Shape* packed_shape) const; + + // Updates the resource with values from `pack`. If `gradient_sources` is + // non-empty, treats `pack` as a tuple that represents a TensorArray and + // its gradients, and unpacks and updates the gradient resources. + // If `reset_initial_values` is true, sets the initial_values as well as the + // values. + // Opposite of Pack(). + Status SetFromPack(const std::set& gradient_sources, + const xla::ComputationDataHandle& pack, + bool reset_initial_values, + xla::ComputationBuilder* builder); + + // TensorArray-specific fields + + // 'tensor_array_size' stores the expected size of the TensorArray or Stack. + // We need to store this since sometimes TensorArrays must be initialized + // lazily since we do not know the element shape at construction time. + int64 tensor_array_size() const { return tensor_array_size_; } + void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } + + // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes + // to an XlaResource containing the gradient TensorArrays. We store a pointer + // here since there should only be one gradient TensorArray per 'source' + // string, irrespective of the number of calls to TensorArrayGrad. The map + // is ordered since values are packed into tuples by Pack() sorted by name + // order. + const std::map>& tensor_array_gradients() + const { + return tensor_array_gradients_; + } + + private: + const Kind kind_; + const int arg_num_; + const string name_; + + DataType type_; + xla::ComputationDataHandle value_; + xla::ComputationDataHandle initial_value_; + + int64 tensor_array_size_ = -1; + + std::map> tensor_array_gradients_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index d3f292207fee396fb4248dede5c0eeb5cd2b87c9..438f1443f17717a3806827abcb36d4ccbbbf756c 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -20,6 +20,10 @@ package_group( load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) # Filegroup used to collect source files for dependency checking. filegroup( @@ -36,6 +40,12 @@ xla_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library_py( + name = "xla_data_proto", # bzl adds a _py suffix + srcs = ["xla_data.proto"], + visibility = ["//visibility:public"], +) + xla_proto_library( name = "xla_proto", srcs = ["xla.proto"], @@ -250,6 +260,7 @@ tf_cc_test( srcs = ["shape_util_test.cc"], deps = [ ":shape_util", + ":status_macros", ":test", ":test_helpers", ":types", @@ -290,7 +301,9 @@ cc_library( ":array2d", ":array3d", ":array4d", + ":shape_tree", ":shape_util", + ":sparse_index_array", ":status_macros", ":types", ":util", @@ -617,6 +630,28 @@ tf_cc_test( ], ) +cc_library( + name = "sparse_index_array", + srcs = ["sparse_index_array.cc"], + hdrs = ["sparse_index_array.h"], + deps = [ + ":array2d", + ":shape_util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "sparse_index_array_test", + srcs = ["sparse_index_array_test.cc"], + deps = [ + ":sparse_index_array", + ":test", + "//tensorflow/core:test_main", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 213e0bac6c77e9972de8d4dd7dfc8c7cf3a1b865..71aa057cd3a1c273c0e851497a78f94ba37c778e 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index f953407a567b91fdf6ae727d6982a2a778c5873e..d6b4ebfc39ae039ff27fe9fb8a3487c870832f3e 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -186,6 +186,20 @@ cc_library( ], ) +cc_library( + name = "sharding_builder", + srcs = ["sharding_builder.cc"], + hdrs = ["sharding_builder.h"], + deps = [ + "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 66937d64aff18817bbd5310e0c24e19556e9d727..d15ccb0c28522c647617153aaa8e738d029dfaba 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -60,7 +60,7 @@ StatusOr> Client::Transfer( "server provided response without a literal in " "TransferToClient request"); } - return MakeUnique(response.literal()); + return Literal::CreateFromProto(*response.mutable_literal()); } StatusOr> Client::TransferToServer( @@ -142,7 +142,7 @@ StatusOr> Client::TransferFromOutfeed( "TransferToClient request"); } - return MakeUnique(response.literal()); + return Literal::CreateFromProto(response.literal()); } Status Client::ResetDevice() { diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index cce931000331e98b00f57025cb13a5d3982c2845..46f2ed4836eda6bf6d5b68f2e29ac6888cd1749b 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -34,25 +34,9 @@ limitations under the License. namespace xla { -ComputationDataHandle ComputationBuilder::ParseOpResponse( - const Status& status, OpResponse* response) { - VLOG(2) << "done with op request"; - - if (!status.ok()) { - NoteError(status); - return ComputationDataHandle(); - } - - if (response->output().handle() == 0) { - NoteError(InternalError("No output handle")); - return ComputationDataHandle(); - } - return response->output(); -} - ComputationBuilder::ComputationBuilder(Client* client, const string& computation_name) - : name_(computation_name), first_error_(Status::OK()), client_(client) {} + : name_(computation_name), client_(client) {} ComputationBuilder::~ComputationBuilder() {} @@ -76,9 +60,8 @@ std::unique_ptr ComputationBuilder::CreateSubBuilder( } Status ComputationBuilder::PrepareComputation() { - if (!first_error_.ok()) { - return first_error_; - } + TF_RETURN_IF_ERROR(first_error_); + if (!computation_.IsNull()) { return Status::OK(); } @@ -100,6 +83,49 @@ Status ComputationBuilder::PrepareComputation() { return Status::OK(); } +Status ComputationBuilder::RunOp(OpRequest* op_request, + OpResponse* op_response) { + TF_RETURN_IF_ERROR(first_error_); + TF_RETURN_IF_ERROR(PrepareComputation()); + + // Fill in fields that are set on every OpRequest. + *op_request->mutable_computation() = computation_.handle(); + *op_request->mutable_metadata() = metadata_; + if (sharding_) { + *op_request->mutable_sharding() = *sharding_; + } + + const string& op_name = + OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name(); + VLOG(2) << "running op request: " << op_name; + Status status = client_->stub()->Op(op_request, op_response); + VLOG(2) << "done with op request: " << op_name; + return status; +} + +void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) { + OpResponse op_response; + Status status = RunOp(op_request, &op_response); + if (!status.ok()) { + NoteError(status); + } +} + +ComputationDataHandle ComputationBuilder::RunOpAndParseResponse( + OpRequest* op_request) { + OpResponse op_response; + Status status = RunOp(op_request, &op_response); + if (!status.ok()) { + NoteError(status); + return ComputationDataHandle(); + } + if (op_response.output().handle() == 0) { + NoteError(InternalError("No output handle")); + return ComputationDataHandle(); + } + return op_response.output(); +} + bool ComputationBuilder::MakeWindow( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, @@ -158,81 +184,55 @@ bool ComputationBuilder::MakeWindow( return true; } -ComputationDataHandle ComputationBuilder::ConstantOp( - const PopulateLiteral& populate) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ConstantRequest request; - Literal literal; - populate(&literal); - *request.mutable_literal() = literal.ToProto(); - VLOG(3) << "created constant: " << request.literal().ShortDebugString(); - OpRequest op_request; - *op_request.mutable_constant_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making constant request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); -} - ComputationDataHandle ComputationBuilder::ConstantLiteral( const Literal& literal) { - return ConstantOp( - [literal](Literal* mutable_literal) { *mutable_literal = literal; }); + OpRequest op_request; + ConstantRequest* request = op_request.mutable_constant_request(); + *request->mutable_literal() = literal.ToProto(); + VLOG(3) << "created constant: " << request->literal().ShortDebugString(); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ParameterRequest request; - *request.mutable_shape() = shape; - request.set_parameter(parameter_number); - request.set_name(name); OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_parameter_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making parameter request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + ParameterRequest* request = op_request.mutable_parameter_request(); + *request->mutable_shape() = shape; + request->set_parameter(parameter_number); + request->set_name(name); + return RunOpAndParseResponse(&op_request); } -StatusOr> ComputationBuilder::GetShape( +StatusOr> ComputationBuilder::GetShapeWithoutNoteError( const ComputationDataHandle& operand) { - if (!first_error_.ok()) { - return first_error_; - } - GetLocalShapeRequest request; *request.mutable_computation() = computation_.handle(); *request.mutable_operand() = operand; GetLocalShapeResponse response; VLOG(2) << "making get-shape request"; - Status s = client_->stub()->GetLocalShape(&request, &response); + TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response)); VLOG(2) << "done with request"; - if (!s.ok()) { - NoteError(s); - return first_error_; - } TF_RET_CHECK(response.has_shape()); std::unique_ptr shape = WrapUnique(response.release_shape()); TF_RET_CHECK(shape != nullptr); return std::move(shape); } +StatusOr> ComputationBuilder::GetShape( + const ComputationDataHandle& operand) { + TF_RETURN_IF_ERROR(first_error_); + + auto status_or_shape = GetShapeWithoutNoteError(operand); + if (!status_or_shape.ok()) { + NoteError(status_or_shape.status()); + return first_error_; + } + return status_or_shape; +} + ComputationDataHandle ComputationBuilder::CheckShape( const ComputationDataHandle& operand, const Shape& expected_shape) { std::unique_ptr actual_shape = GetShape(operand).ConsumeValueOrDie(); @@ -258,30 +258,19 @@ ComputationDataHandle ComputationBuilder::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - SliceRequest request; - *request.mutable_operand() = operand; + OpRequest op_request; + SliceRequest* request = op_request.mutable_slice_request(); + *request->mutable_operand() = operand; for (int64 index : start_indices) { - request.add_start_indices(index); + request->add_start_indices(index); } for (int64 index : limit_indices) { - request.add_limit_indices(index); + request->add_limit_indices(index); } for (int64 index : strides) { - request.add_strides(index); + request->add_strides(index); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_slice_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making slice request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::SliceInDim( @@ -307,143 +296,78 @@ ComputationDataHandle ComputationBuilder::DynamicSlice( const ComputationDataHandle& operand, const ComputationDataHandle& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - DynamicSliceRequest request; - *request.mutable_operand() = operand; - *request.mutable_start_indices() = start_indices; + OpRequest op_request; + DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request(); + *request->mutable_operand() = operand; + *request->mutable_start_indices() = start_indices; for (int64 index : slice_sizes) { - request.add_slice_sizes(index); + request->add_slice_sizes(index); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_dynamic_slice_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making dynamic slice request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( const ComputationDataHandle& operand, const ComputationDataHandle& update, const ComputationDataHandle& start_indices) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - DynamicUpdateSliceRequest request; - *request.mutable_operand() = operand; - *request.mutable_update() = update; - *request.mutable_start_indices() = start_indices; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_dynamic_update_slice_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making dynamic update slice request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + DynamicUpdateSliceRequest* request = + op_request.mutable_dynamic_update_slice_request(); + *request->mutable_operand() = operand; + *request->mutable_update() = update; + *request->mutable_start_indices() = start_indices; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::ConcatInDim( tensorflow::gtl::ArraySlice operands, int64 dimension) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ConcatenateRequest request; + OpRequest op_request; + ConcatenateRequest* request = op_request.mutable_concatenate_request(); for (const ComputationDataHandle& operand : operands) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - request.set_dimension(dimension); - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_concatenate_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making concatenate request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + request->set_dimension(dimension); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Broadcast( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - BroadcastRequest request; - *request.mutable_operand() = operand; + OpRequest op_request; + BroadcastRequest* request = op_request.mutable_broadcast_request(); + *request->mutable_operand() = operand; for (int64 size : broadcast_sizes) { - request.add_broadcast_sizes(size); + request->add_broadcast_sizes(size); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_broadcast_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making broadcast request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Pad( const ComputationDataHandle& operand, const ComputationDataHandle& padding_value, const PaddingConfig& padding_config) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - PadRequest request; - *request.mutable_operand() = operand; - *request.mutable_padding_value() = padding_value; - *request.mutable_padding_config() = padding_config; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_pad_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making pad request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + PadRequest* request = op_request.mutable_pad_request(); + *request->mutable_operand() = operand; + *request->mutable_padding_value() = padding_value; + *request->mutable_padding_config() = padding_config; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Reshape( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReshapeRequest request; - *request.mutable_operand() = operand; + OpRequest op_request; + ReshapeRequest* request = op_request.mutable_reshape_request(); + *request->mutable_operand() = operand; for (int64 dimension : dimensions) { - request.add_dimensions(dimension); + request->add_dimensions(dimension); } for (int64 new_size : new_sizes) { - request.add_new_sizes(new_size); + request->add_new_sizes(new_size); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reshape_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making reshape request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Reshape( @@ -455,7 +379,6 @@ ComputationDataHandle ComputationBuilder::Reshape( StatusOr> shape = GetShape(operand); if (!shape.ok()) { - first_error_ = shape.status(); return ComputationDataHandle(); } std::vector dimensions(shape.ValueOrDie()->dimensions().size()); @@ -485,7 +408,6 @@ ComputationDataHandle ComputationBuilder::Collapse( // dimensions by the product of their sizes. StatusOr> shape_or_status = GetShape(operand); if (!shape_or_status.ok()) { - first_error_ = shape_or_status.status(); return ComputationDataHandle(); } std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); @@ -517,26 +439,11 @@ ComputationDataHandle ComputationBuilder::Collapse( void ComputationBuilder::Trace(const string& tag, const ComputationDataHandle& operand) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return; - } - - TraceRequest request; - request.set_tag(tag); - *request.mutable_operand() = operand; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_trace_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making trace request"; - Status s = client_->stub()->Op(&op_request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - NoteError(s); - } + TraceRequest* request = op_request.mutable_trace_request(); + request->set_tag(tag); + *request->mutable_operand() = operand; + RunOpAndNoteError(&op_request); } ComputationDataHandle ComputationBuilder::Select( @@ -547,44 +454,23 @@ ComputationDataHandle ComputationBuilder::Select( ComputationDataHandle ComputationBuilder::Tuple( tensorflow::gtl::ArraySlice elements) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - VariadicOpRequest request; - request.set_varop(VAROP_TUPLE); + OpRequest op_request; + VariadicOpRequest* request = op_request.mutable_variadic_op_request(); + request->set_varop(VAROP_TUPLE); for (const ComputationDataHandle& operand : elements) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_variadic_op_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making variadic op request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::GetTupleElement( const ComputationDataHandle& tuple_data, int64 index) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - GetTupleElementRequest request; - *request.mutable_operand() = tuple_data; - request.set_index(index); OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_get_tuple_element_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making get tuple element op request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + GetTupleElementRequest* request = + op_request.mutable_get_tuple_element_request(); + *request->mutable_operand() = tuple_data; + request->set_index(index); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Eq( @@ -625,16 +511,33 @@ ComputationDataHandle ComputationBuilder::Lt( ComputationDataHandle ComputationBuilder::Dot( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return BinaryOp(BINOP_DOT, lhs, rhs, /*broadcast_dimensions=*/{}); + StatusOr> lhs_shape_or_status = GetShape(lhs); + if (!lhs_shape_or_status.ok()) { + return ComputationDataHandle(); + } + std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape->dimensions_size() == 1 ? 0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); +} + +ComputationDataHandle ComputationBuilder::DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers) { + OpRequest op_request; + DotRequest* request = op_request.mutable_dot_request(); + *request->mutable_lhs() = lhs; + *request->mutable_rhs() = rhs; + *request->mutable_dimension_numbers() = dimension_numbers; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Conv( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size())); @@ -644,10 +547,6 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size())); } @@ -715,13 +614,11 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( StatusOr> lhs_shape_or_status = GetShape(lhs); if (!lhs_shape_or_status.ok()) { - first_error_ = lhs_shape_or_status.status(); return ComputationDataHandle(); } StatusOr> rhs_shape_or_status = GetShape(rhs); if (!rhs_shape_or_status.ok()) { - first_error_ = rhs_shape_or_status.status(); return ComputationDataHandle(); } @@ -776,13 +673,11 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated( StatusOr> lhs_shape_or_status = GetShape(lhs); if (!lhs_shape_or_status.ok()) { - first_error_ = lhs_shape_or_status.status(); return ComputationDataHandle(); } StatusOr> rhs_shape_or_status = GetShape(rhs); if (!rhs_shape_or_status.ok()) { - first_error_ = rhs_shape_or_status.status(); return ComputationDataHandle(); } @@ -800,122 +695,78 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated( rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } - ConvolveRequest request; - *request.mutable_lhs() = lhs; - *request.mutable_rhs() = rhs; - *request.mutable_dimension_numbers() = dimension_numbers; + OpRequest op_request; + ConvolveRequest* request = op_request.mutable_convolve_request(); + *request->mutable_lhs() = lhs; + *request->mutable_rhs() = rhs; + *request->mutable_dimension_numbers() = dimension_numbers; if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, - rhs_dilation, request.mutable_window())) { + rhs_dilation, request->mutable_window())) { // Error is recorded in MakeWindow. return ComputationDataHandle(); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_convolve_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - VLOG(2) << "making convolve request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } -ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, - const string& config) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); +ComputationDataHandle ComputationBuilder::Fft( + const ComputationDataHandle& operand, const FftType fft_type, + const tensorflow::gtl::ArraySlice fft_length) { + OpRequest op_request; + FftRequest* request = op_request.mutable_fft_request(); + *request->mutable_operand() = operand; + request->set_fft_type(fft_type); + for (int64 dim_len : fft_length) { + request->add_fft_length(dim_len); } + return RunOpAndParseResponse(&op_request); +} - InfeedRequest request; - *request.mutable_shape() = shape; - *request.mutable_config() = config; +ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, + const string& config) { OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_infeed_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making infeed op request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + InfeedRequest* request = op_request.mutable_infeed_request(); + *request->mutable_shape() = shape; + *request->mutable_config() = config; + return RunOpAndParseResponse(&op_request); } void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, const Shape& shape, const string& outfeed_config) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return; - } - - OutfeedRequest request; - request.set_outfeed_config(outfeed_config); - *request.mutable_operand() = operand; - *request.mutable_shape() = shape; OpRequest op_request; - *op_request.mutable_outfeed_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making outfeed op request"; - tensorflow::Status s = client_->stub()->Op(&op_request, &response); - - if (!s.ok()) { - NoteError(s); - return; - } + OutfeedRequest* request = op_request.mutable_outfeed_request(); + request->set_outfeed_config(outfeed_config); + *request->mutable_operand() = operand; + *request->mutable_shape() = shape; + RunOpAndNoteError(&op_request); } ComputationDataHandle ComputationBuilder::Call( const Computation& computation, tensorflow::gtl::ArraySlice operands) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - CallRequest request; - *request.mutable_to_apply() = computation.handle(); + OpRequest op_request; + CallRequest* request = op_request.mutable_call_request(); + *request->mutable_to_apply() = computation.handle(); for (const ComputationDataHandle& operand : operands) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_call_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making call op request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::CustomCall( const string& call_target_name, tensorflow::gtl::ArraySlice operands, const Shape& shape) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - CustomCallRequest request; - request.set_call_target_name(call_target_name); + OpRequest op_request; + CustomCallRequest* request = op_request.mutable_custom_call_request(); + request->set_call_target_name(call_target_name); for (const ComputationDataHandle& operand : operands) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - *request.mutable_shape() = shape; - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_custom_call_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making custom call op request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + *request->mutable_shape() = shape; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Complex( @@ -1080,47 +931,25 @@ ComputationDataHandle ComputationBuilder::IsFinite( ComputationDataHandle ComputationBuilder::Transpose( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice permutation) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); TransposeRequest* request = op_request.mutable_transpose_request(); *request->mutable_operand() = operand; for (int64 dimension : permutation) { request->add_dimensions(dimension); } - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making transpose request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Rev( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice dimensions) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReverseRequest request; - *request.mutable_operand() = operand; + OpRequest op_request; + ReverseRequest* request = op_request.mutable_reverse_request(); + *request->mutable_operand() = operand; for (int64 dimension : dimensions) { - request.add_dimensions(dimension); + request->add_dimensions(dimension); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reverse_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making reverse op request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Sort( @@ -1148,24 +977,15 @@ ComputationDataHandle ComputationBuilder::ConvertElementType( StatusOr> shape_status = GetShape(operand); if (!shape_status.ok()) { - first_error_ = shape_status.status(); return ComputationDataHandle(); } std::unique_ptr original = shape_status.ConsumeValueOrDie(); - ConvertRequest request; - *request.mutable_operand() = operand; - request.set_new_element_type(new_element_type); OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_convert_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making convert request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + ConvertRequest* request = op_request.mutable_convert_request(); + *request->mutable_operand() = operand; + request->set_new_element_type(new_element_type); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BitcastConvertType( @@ -1176,24 +996,15 @@ ComputationDataHandle ComputationBuilder::BitcastConvertType( StatusOr> shape_status = GetShape(operand); if (!shape_status.ok()) { - first_error_ = shape_status.status(); return ComputationDataHandle(); } std::unique_ptr original = shape_status.ConsumeValueOrDie(); - ConvertRequest request; - *request.mutable_operand() = operand; - request.set_new_element_type(new_element_type); OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_bitcast_convert_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making bitcast convert request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + ConvertRequest* request = op_request.mutable_bitcast_convert_request(); + *request->mutable_operand() = operand; + request->set_new_element_type(new_element_type); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::SquareF32( @@ -1221,107 +1032,57 @@ ComputationDataHandle ComputationBuilder::Clamp( ComputationDataHandle ComputationBuilder::UnaryOp( UnaryOperation unop, const ComputationDataHandle& operand) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - UnaryOpRequest request; - request.set_unop(unop); - *request.mutable_operand() = operand; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_unary_op_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making unop request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + UnaryOpRequest* request = op_request.mutable_unary_op_request(); + request->set_unop(unop); + *request->mutable_operand() = operand; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BinaryOp( BinaryOperation binop, const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - BinaryOpRequest request; - request.set_binop(binop); - *request.mutable_lhs() = lhs; - *request.mutable_rhs() = rhs; + OpRequest op_request; + BinaryOpRequest* request = op_request.mutable_binary_op_request(); + request->set_binop(binop); + *request->mutable_lhs() = lhs; + *request->mutable_rhs() = rhs; for (int64 dimension : broadcast_dimensions) { - request.add_broadcast_dimensions(dimension); + request->add_broadcast_dimensions(dimension); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_binary_op_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making binop request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::RngOp( RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters, const Shape& shape) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - RngRequest request; - request.set_distribution(distribution); + OpRequest op_request; + RngRequest* request = op_request.mutable_rng_request(); + request->set_distribution(distribution); for (const ComputationDataHandle& param : parameters) { - *request.add_parameter() = param; + *request->add_parameter() = param; } - *request.mutable_shape() = shape; - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_rng_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making rngop request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + *request->mutable_shape() = shape; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::TernaryOp( TernaryOperation triop, const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - TernaryOpRequest request; - request.set_triop(triop); - *request.mutable_lhs() = lhs; - *request.mutable_rhs() = rhs; - *request.mutable_ehs() = ehs; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_ternary_op_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making triop request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + TernaryOpRequest* request = op_request.mutable_ternary_op_request(); + request->set_triop(triop); + *request->mutable_lhs() = lhs; + *request->mutable_rhs() = rhs; + *request->mutable_ehs() = ehs; + return RunOpAndParseResponse(&op_request); } Status ComputationBuilder::SetReturnValue( const ComputationDataHandle& operand) { - if (!first_error_.ok()) { - return first_error_; - } + TF_RETURN_IF_ERROR(first_error_); SetReturnValueRequest request; *request.mutable_computation() = computation_.handle(); @@ -1343,9 +1104,7 @@ Status ComputationBuilder::SetReturnValue( StatusOr ComputationBuilder::IsConstant( const ComputationDataHandle& operand, int64 num_parameters) { - if (!first_error_.ok()) { - return first_error_; - } + TF_RETURN_IF_ERROR(first_error_); IsConstantRequest request; *request.mutable_computation() = computation_.handle(); @@ -1366,9 +1125,7 @@ StatusOr ComputationBuilder::IsConstant( StatusOr> ComputationBuilder::ComputeConstant( const ComputationDataHandle& operand, const Layout* output_layout, tensorflow::gtl::ArraySlice parameters) { - if (!first_error_.ok()) { - return first_error_; - } + TF_RETURN_IF_ERROR(first_error_); ComputeConstantRequest request; *request.mutable_computation() = computation_.handle(); @@ -1397,7 +1154,7 @@ StatusOr> ComputationBuilder::ComputeConstant( "no computed literal in the provided response in ComputeConstant " "request"); } - return MakeUnique(response.literal()); + return Literal::CreateFromProto(response.literal()); } ComputationDataHandle ComputationBuilder::Map( @@ -1405,30 +1162,19 @@ ComputationDataHandle ComputationBuilder::Map( const Computation& computation, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice static_operands) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - MapRequest request; + OpRequest op_request; + MapRequest* request = op_request.mutable_map_request(); for (const ComputationDataHandle& operand : operands) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - *request.mutable_to_apply() = computation.handle(); + *request->mutable_to_apply() = computation.handle(); for (int64 dimension : dimensions) { - request.add_dimensions(dimension); + request->add_dimensions(dimension); } for (const ComputationDataHandle& sop : static_operands) { - *request.add_static_operands() = sop; + *request->add_static_operands() = sop; } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_map_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making Map request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::RngNormal( @@ -1443,57 +1189,46 @@ ComputationDataHandle ComputationBuilder::RngUniform( return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); } -ComputationDataHandle ComputationBuilder::RngBernoulli( - const ComputationDataHandle& mean, const Shape& shape) { - return RngOp(RandomDistribution::RNG_BERNOULLI, {mean}, shape); -} - ComputationDataHandle ComputationBuilder::While( const Computation& condition, const Computation& body, const ComputationDataHandle& init) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - WhileRequest request; - *request.mutable_condition() = condition.handle(); - *request.mutable_body() = body.handle(); - *request.mutable_init() = init; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_while_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making while request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + WhileRequest* request = op_request.mutable_while_request(); + *request->mutable_condition() = condition.handle(); + *request->mutable_body() = body.handle(); + *request->mutable_init() = init; + return RunOpAndParseResponse(&op_request); +} + +ComputationDataHandle ComputationBuilder::Conditional( + const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation) { + OpRequest op_request; + ConditionalRequest* request = op_request.mutable_conditional_request(); + *request->mutable_predicate() = predicate; + *request->mutable_true_operand() = true_operand; + *request->mutable_true_computation() = true_computation.handle(); + *request->mutable_false_operand() = false_operand; + *request->mutable_false_computation() = false_computation.handle(); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Reduce( const ComputationDataHandle& operand, const ComputationDataHandle& init_value, const Computation& computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReduceRequest request; - *request.mutable_operand() = operand; - *request.mutable_init_value() = init_value; + OpRequest op_request; + ReduceRequest* request = op_request.mutable_reduce_request(); + *request->mutable_operand() = operand; + *request->mutable_init_value() = init_value; for (int64 dimension : dimensions_to_reduce) { - request.add_dimensions(dimension); + request->add_dimensions(dimension); } - *request.mutable_to_apply() = computation.handle(); - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reduce_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making reduce request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + *request->mutable_to_apply() = computation.handle(); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::ReduceAll( @@ -1505,7 +1240,6 @@ ComputationDataHandle ComputationBuilder::ReduceAll( StatusOr> shape = GetShape(operand); if (!shape.ok()) { - first_error_ = shape.status(); return ComputationDataHandle(); } @@ -1525,7 +1259,6 @@ ComputationDataHandle ComputationBuilder::ReduceWindow( StatusOr> shape = GetShape(operand); if (!shape.ok()) { - first_error_ = shape.status(); return ComputationDataHandle(); } @@ -1551,84 +1284,50 @@ ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReduceWindowRequest request; - *request.mutable_operand() = operand; - *request.mutable_to_apply() = computation.handle(); - *request.mutable_init_value() = init_value; + OpRequest op_request; + ReduceWindowRequest* request = op_request.mutable_reduce_window_request(); + *request->mutable_operand() = operand; + *request->mutable_to_apply() = computation.handle(); + *request->mutable_init_value() = init_value; if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request.mutable_window())) { + request->mutable_window())) { NoteError(InternalError("failed to make window")); return ComputationDataHandle(); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reduce_window_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - VLOG(2) << "making reduce-window request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BatchNormTraining( const ComputationDataHandle& operand, const ComputationDataHandle& scale, const ComputationDataHandle& offset, float epsilon, int64 feature_index) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - BatchNormTrainingRequest request; - *request.mutable_operand() = operand; - *request.mutable_scale() = scale; - *request.mutable_offset() = offset; - request.set_epsilon(epsilon); - request.set_feature_index(feature_index); - OpRequest op_request; - *op_request.mutable_batch_norm_training_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - - OpResponse response; - - VLOG(2) << "making BatchNormTraining request"; - - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + BatchNormTrainingRequest* request = + op_request.mutable_batch_norm_training_request(); + *request->mutable_operand() = operand; + *request->mutable_scale() = scale; + *request->mutable_offset() = offset; + request->set_epsilon(epsilon); + request->set_feature_index(feature_index); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BatchNormInference( const ComputationDataHandle& operand, const ComputationDataHandle& scale, const ComputationDataHandle& offset, const ComputationDataHandle& mean, const ComputationDataHandle& variance, float epsilon, int64 feature_index) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - BatchNormInferenceRequest request; - *request.mutable_operand() = operand; - *request.mutable_scale() = scale; - *request.mutable_offset() = offset; - *request.mutable_mean() = mean; - *request.mutable_variance() = variance; - request.set_epsilon(epsilon); - request.set_feature_index(feature_index); - OpRequest op_request; - *op_request.mutable_batch_norm_inference_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - - OpResponse response; - - VLOG(2) << "making BatchNormInference request"; - - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + BatchNormInferenceRequest* request = + op_request.mutable_batch_norm_inference_request(); + *request->mutable_operand() = operand; + *request->mutable_scale() = scale; + *request->mutable_offset() = offset; + *request->mutable_mean() = mean; + *request->mutable_variance() = variance; + request->set_epsilon(epsilon); + request->set_feature_index(feature_index); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BatchNormGrad( @@ -1636,49 +1335,25 @@ ComputationDataHandle ComputationBuilder::BatchNormGrad( const ComputationDataHandle& mean, const ComputationDataHandle& var, const ComputationDataHandle& grad_output, float epsilon, int64 feature_index) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - BatchNormGradRequest request; - *request.mutable_operand() = operand; - *request.mutable_scale() = scale; - *request.mutable_mean() = mean; - *request.mutable_variance() = var; - *request.mutable_grad_output() = grad_output; - request.set_epsilon(epsilon); - request.set_feature_index(feature_index); - OpRequest op_request; - *op_request.mutable_batch_norm_grad_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - - OpResponse response; - - VLOG(2) << "making BatchNormGrad request"; - - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request(); + *request->mutable_operand() = operand; + *request->mutable_scale() = scale; + *request->mutable_mean() = mean; + *request->mutable_variance() = var; + *request->mutable_grad_output() = grad_output; + request->set_epsilon(epsilon); + request->set_feature_index(feature_index); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::CrossReplicaSum( const ComputationDataHandle& operand) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - CrossReplicaSumRequest request; - *request.mutable_operand() = operand; OpRequest op_request; - *op_request.mutable_cross_replica_sum_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making cross-replica-sum request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + CrossReplicaSumRequest* request = + op_request.mutable_cross_replica_sum_request(); + *request->mutable_operand() = operand; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::SelectAndScatter( @@ -1693,7 +1368,6 @@ ComputationDataHandle ComputationBuilder::SelectAndScatter( StatusOr> shape = GetShape(operand); if (!shape.ok()) { - first_error_ = shape.status(); return ComputationDataHandle(); } return SelectAndScatterWithGeneralPadding( @@ -1710,98 +1384,53 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( tensorflow::gtl::ArraySlice> padding, const ComputationDataHandle& source, const ComputationDataHandle& init_value, const Computation& scatter) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - SelectAndScatterRequest request; - *request.mutable_operand() = operand; - *request.mutable_select() = select.handle(); - *request.mutable_source() = source; - *request.mutable_init_value() = init_value; - *request.mutable_scatter() = scatter.handle(); + OpRequest op_request; + SelectAndScatterRequest* request = + op_request.mutable_select_and_scatter_request(); + *request->mutable_operand() = operand; + *request->mutable_select() = select.handle(); + *request->mutable_source() = source; + *request->mutable_init_value() = init_value; + *request->mutable_scatter() = scatter.handle(); if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request.mutable_window())) { + request->mutable_window())) { NoteError(InternalError("failed to make window")); return ComputationDataHandle(); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_select_and_scatter_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - VLOG(2) << "making select-and-scatter request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::ReducePrecision( const ComputationDataHandle& operand, const int exponent_bits, const int mantissa_bits) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReducePrecisionRequest request; - *request.mutable_operand() = operand; - request.set_exponent_bits(exponent_bits); - request.set_mantissa_bits(mantissa_bits); OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reduce_precision_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making reduce-precision request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + ReducePrecisionRequest* request = + op_request.mutable_reduce_precision_request(); + *request->mutable_operand() = operand; + request->set_exponent_bits(exponent_bits); + request->set_mantissa_bits(mantissa_bits); + return RunOpAndParseResponse(&op_request); } void ComputationBuilder::Send(const ComputationDataHandle& operand, const ChannelHandle& handle) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return; - } - - SendRequest request; - *request.mutable_operand() = operand; - *request.mutable_channel_handle() = handle; OpRequest op_request; - *op_request.mutable_send_request() = request; + SendRequest* request = op_request.mutable_send_request(); + *request->mutable_operand() = operand; + *request->mutable_channel_handle() = handle; *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making send request"; - Status s = client_->stub()->Op(&op_request, &response); - VLOG(2) << "done with op request"; - - if (!s.ok()) { - NoteError(s); - return; - } + RunOpAndNoteError(&op_request); } ComputationDataHandle ComputationBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - RecvRequest request; - *request.mutable_shape() = shape; - *request.mutable_channel_handle() = handle; OpRequest op_request; - *op_request.mutable_recv_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making recv request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + RecvRequest* request = op_request.mutable_recv_request(); + *request->mutable_shape() = shape; + *request->mutable_channel_handle() = handle; + return RunOpAndParseResponse(&op_request); } Computation ComputationBuilder::BuildAndNoteError() { @@ -1830,13 +1459,6 @@ StatusOr ComputationBuilder::Build() { return {std::move(computation_)}; } -void ComputationBuilder::AddCommonFieldsToOpRequest(OpRequest* request) const { - *request->mutable_metadata() = metadata_; - if (sharding_) { - *request->mutable_sharding() = *sharding_; - } -} - /* static */ ConvolutionDimensionNumbers ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { ConvolutionDimensionNumbers dimension_numbers; diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index d2dbbbbebbd5a9386f8841576de33a1fdb767000..d82ba63e8ad0b9ceac0eb5f0cd7720cac0cbe6d3 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -43,59 +43,6 @@ limitations under the License. namespace xla { -class ShardingBuilder { - public: - // A shaped array used to describe the assignment of tiles to devices. - using TileAssignment = Array; - - // Creates a replicated sharding - replicate a tensor on every device. - static OpSharding Replicate() { - OpSharding result; - result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); - return result; - } - // Creates a sharding that assigns a tensor to just one device. - static OpSharding AssignDevice(int device) { - OpSharding result; - result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - result.add_tile_assignment_dimensions(1); - result.add_tile_assignment_devices(device); - return result; - } - // Creates a tiled sharding with the given tile shape and assignment of tiles - // to devices. - static OpSharding Tile(Shape tile_shape, - const TileAssignment& tile_assignment) { - OpSharding result; - result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - *result.mutable_tile_shape() = tile_shape; - for (int64 dim : tile_assignment.dimensions()) { - result.add_tile_assignment_dimensions(dim); - } - for (uint32 device : tile_assignment) { - result.add_tile_assignment_devices(device); - } - return result; - } - // Creates a sharding in one dimension, with the given tile shape which must - // be rank 1 and using devices 0..num_tiles. - static OpSharding Tile1D(Shape tile_shape, int64 num_tiles) { - OpSharding result; - result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - - CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); - std::vector dimensions(1, num_tiles); - auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; - tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); - *result.mutable_tile_shape() = tile_shape; - result.add_tile_assignment_dimensions(num_tiles); - for (int64 i = 0; i < num_tiles; ++i) { - result.add_tile_assignment_devices(i); - } - return result; - } -}; - // Wraps an XLA client with a convenient interface for building up // computations. Any errors encountered in building up the computation are // deferred from being handled until Build() is called. @@ -393,6 +340,11 @@ class ComputationBuilder { ComputationDataHandle Dot(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + // Enqueues a general dot instruction onto the computation. + ComputationDataHandle DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers); + // Default dimension numbers used for a 2D convolution. static constexpr int64 kConvBatchDimension = 0; static constexpr int64 kConvFeatureDimension = 1; @@ -458,14 +410,24 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); + // Enqueues an FFT instruction onto the computation, of the given type and + // with the given FFT length. + ComputationDataHandle Fft(const ComputationDataHandle& operand, + FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); // Enqueues an outfeed instruction onto the computation. This instruction // generates outgoing data transfers for the given data. - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, - const string& outfeed_config); + // + // shape_with_layout communicates the laid out shape that we want to outfeed + // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error + // will occur. + void Outfeed(const ComputationDataHandle& operand, + const Shape& shape_with_layout, const string& outfeed_config); // Enqueues a call instruction onto the computation. ComputationDataHandle Call( @@ -726,16 +688,18 @@ class ComputationBuilder { const ComputationDataHandle& b, const Shape& shape); - // Enqueues a B(1, p) random number generation instruction onto the - // computation. - ComputationDataHandle RngBernoulli(const ComputationDataHandle& mean, - const Shape& shape); - // Enqueues a while node onto the computation. ComputationDataHandle While(const Computation& condition, const Computation& body, const ComputationDataHandle& init); + // Enqueues a conditional node onto the computation. + ComputationDataHandle Conditional(const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation); + // Enqueues a ReducePrecision node onto the computation. ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, const int exponent_bits, @@ -751,7 +715,7 @@ class ComputationBuilder { ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters with higher index then + // constant does not depend on parameters with index greater than or equal to // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a // compile-time constant without evaluating the computation. @@ -811,7 +775,7 @@ class ComputationBuilder { // The operand must represent a constant value, which in this case // means that it must not statically depend on any parameter of the // computation that is being built other then the ones specified on the - // paramtere list. The parameters in the list will be indexed by their + // parameter list. The parameters in the list will be indexed by their // parameter id property so the number of parameters specified should be at // least as many as the largest used parameter index. // @@ -870,8 +834,6 @@ class ComputationBuilder { Status first_error() const { return first_error_; } private: - using PopulateLiteral = std::function; - // Limited checking of convolution parameters. Returns false on // error. bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape, @@ -890,11 +852,6 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice rhs_dilation, Window* window); - // Internal helper method that makes a request for a constant operation -- the - // provided function is used to populate the literal before sending the - // request. - ComputationDataHandle ConstantOp(const PopulateLiteral& populate); - // Internal helper method that does the building for an arbitrary unary op. ComputationDataHandle UnaryOp(UnaryOperation binop, const ComputationDataHandle& operand); @@ -924,19 +881,28 @@ class ComputationBuilder { // This is used before any given operation is enqueued. Status PrepareComputation(); - // Helper function for parsing a method response and either returning the - // output computation data handle (on success) or a vacuous computation data - // handle (on failure). - ComputationDataHandle ParseOpResponse(const Status& status, - OpResponse* response); - // Notes that the error occurred by: // * storing it internally and capturing a backtrace if it's the first error // (this deferred value will be produced on the call to Build()) // * dying if die_immediately_on_error_ is true void NoteError(const Status& error); - void AddCommonFieldsToOpRequest(OpRequest* request) const; + // Helper function that runs the given op_request, filling in op_response. + // Before the op is run, PrepareComputation is called, and common fields in + // the op_request are filled in. + Status RunOp(OpRequest* op_request, OpResponse* op_response); + + // Helper function that calls RunOp and calls NoteError on failures. + void RunOpAndNoteError(OpRequest* op_request); + + // Helper function that calls RunOp and either returns the output computation + // data handle (on success) or a vacuous computation data handle (on failure). + ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request); + + // Helper function that implements GetShape without noting errors. This makes + // it easier to ensure the real GetShape will note errors on every error path. + StatusOr> GetShapeWithoutNoteError( + const ComputationDataHandle& operand); string name_; // Name to use for the built computation. @@ -970,68 +936,66 @@ class ComputationBuilder { template ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantOp([value](Literal* literal) { literal->PopulateR0(value); }); + return ConstantLiteral(*Literal::CreateR0(value)); } template ComputationDataHandle ComputationBuilder::ConstantR1( tensorflow::gtl::ArraySlice values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR1(values); }); + return ConstantLiteral(*Literal::CreateR1(values)); } template ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, NativeT value) { - return ConstantOp([length, value](Literal* literal) { - literal->PopulateWithValue(value, {length}); - }); + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(literal); } inline ComputationDataHandle ComputationBuilder::ConstantR1( const tensorflow::core::Bitmap& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR1(values); }); + return ConstantLiteral(*Literal::CreateR1(values)); } template ComputationDataHandle ComputationBuilder::ConstantR2( std::initializer_list> values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR2(values); }); + return ConstantLiteral(*Literal::CreateR2(values)); } template ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( const Array& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateFromArrayWithLayout(values, layout); - }); + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); } template ComputationDataHandle ComputationBuilder::ConstantFromArray( const Array& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateFromArray(values); }); + return ConstantLiteral(*Literal::CreateFromArray(values)); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D& values) { - return ConstantFromArray(values); + return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); + return ConstantLiteral( + *Literal::CreateR3FromArray3DWithLayout(values, layout)); } template diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index b051955f0fd85b7ca886bc0238068aeb94427209..523169fdd266d445c9d0d056ba20091f77610ad9 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -78,14 +78,14 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( } for (int i = 0; i < arguments.size(); ++i) { if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( - arguments[i]->shape())) { + arguments[i]->on_host_shape())) { return InvalidArgument( "argument does not match shape or layout of computation parameter " "%d: expected %s, got %s", i, ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) .c_str(), - ShapeUtil::HumanString(arguments[i]->shape()).c_str()); + ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } } @@ -184,7 +184,7 @@ StatusOr> LocalExecutable::Run( } TF_ASSIGN_OR_RETURN( std::unique_ptr result, - executable_->ExecuteOnStreamWrapper>( + executable_->ExecuteOnStreamWrapper( &service_options, options.execution_profile(), arguments)); return ScopedShapedBuffer::MakeScoped(result.get(), actual_options.allocator()); @@ -281,13 +281,9 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, if (allocator == nullptr) { allocator = backend().memory_allocator(); } - TF_ASSIGN_OR_RETURN( - auto scoped_buffer, - ScopedShapedBuffer::Allocate( - literal.shape(), allocator, device_ordinal, - [this](const Shape& shape) { - return backend().transfer_manager()->GetByteSizeRequirement(shape); - })); + TF_ASSIGN_OR_RETURN(auto scoped_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), allocator, device_ordinal)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( @@ -322,4 +318,8 @@ StatusOr> LocalClient::TransferFromOutfeedLocal( return std::move(literal); } +StatusOr LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) { + return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 3ca0d2ef5513cfb6b0dbfbc63b311f81a318356e..19fd14f76bc69d528193f7981a51a305f03f987e 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -176,6 +176,13 @@ class LocalClient : public Client { StatusOr> TransferFromOutfeedLocal( const Shape& shape, int device_ordinal); + // Returns the device ordinal that corresponds to the given replica number. + // + // This returns an error if there is not a one-to-one correspondence of + // replicas to device ordinals, but is useful as a short term mechanism for + // the "easy" case where a single replica is a single device. + StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + // Returns the platform that the underlying service targets. perftools::gputools::Platform* platform() const; diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..176802b33ef824a1f898255a19e44def3c1fc982 --- /dev/null +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/sharding_builder.h" + +namespace xla { +namespace sharding_builder { + +OpSharding Replicate() { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + return result; +} + +OpSharding AssignDevice(int device) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + result.add_tile_assignment_dimensions(1); + result.add_tile_assignment_devices(device); + return result; +} + +OpSharding Tile(const Shape& tile_shape, + const TileAssignment& tile_assignment) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + *result.mutable_tile_shape() = tile_shape; + for (int64 dim : tile_assignment.dimensions()) { + result.add_tile_assignment_dimensions(dim); + } + for (uint32 device : tile_assignment) { + result.add_tile_assignment_devices(device); + } + return result; +} + +OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + + CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); + std::vector dimensions(1, num_tiles); + *result.mutable_tile_shape() = tile_shape; + auto& tile_dimension = + (*result.mutable_tile_shape()->mutable_dimensions())[0]; + tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); + result.add_tile_assignment_dimensions(num_tiles); + for (int64 i = 0; i < num_tiles; ++i) { + result.add_tile_assignment_devices(i); + } + return result; +} + +OpSharding Tuple(const ShapeTree& shardings) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + for (const auto& index_to_sharding : shardings.leaves()) { + *result.add_tuple_shardings() = index_to_sharding.second; + } + return result; +} + +} // namespace sharding_builder +} // namespace xla diff --git a/tensorflow/compiler/xla/client/sharding_builder.h b/tensorflow/compiler/xla/client/sharding_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..34763e54d946690289ff42a7712b980168933eee --- /dev/null +++ b/tensorflow/compiler/xla/client/sharding_builder.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_SHARDING_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_SHARDING_BUILDER_H_ + +#include + +#include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace sharding_builder { +// A shaped array used to describe the assignment of tiles to devices. +using TileAssignment = Array; + +// Creates a replicated sharding - replicate a tensor on every device. +OpSharding Replicate(); + +// Creates a sharding that assigns a tensor to just one device. +OpSharding AssignDevice(int device); + +// Creates a tiled sharding with the given tile shape and assignment of tiles +// to devices. +// +// If tile_shape is not evenly divisible by the number of devices in +// tile_assignment, operations behave as if implicit padding had been inserted. +// The value of this padding is undefined. +OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment); + +// Creates a sharding in one dimension, with the given tile shape which must +// be rank 1 and using devices [0..num_tiles). +// +// This is simply a convenience wrapper for Tile(). +OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles); + +// Creates a tuple sharding from the given ShapeTree of element shardings. +OpSharding Tuple(const ShapeTree& shardings); + +} // namespace sharding_builder +} // namespace xla + +#endif diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 33d5b6f1d4d15d5143a3421c87eab9b7a7d11345..392ad9010ab81923a089c7b00a79ddc281af92bb 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -83,7 +83,7 @@ ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( return *this; } -DeviceAssignment* ExecutableRunOptions::device_assignment() const { +const DeviceAssignment* ExecutableRunOptions::device_assignment() const { return device_assignment_; } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index deb3ddb203d263d25bef0499a8a53a6098d0de0c..d4fcbf0493c936ebcd0639a432e56b62ee15672c 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -82,7 +82,7 @@ class ExecutableRunOptions { ExecutableRunOptions& set_device_assignment( DeviceAssignment* device_assignment); - DeviceAssignment* device_assignment() const; + const DeviceAssignment* device_assignment() const; private: DeviceMemoryAllocator* allocator_ = nullptr; diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 76c0168f370ff1f0749759705b7ecff359a80341..ffd1fb79e986f82e1c2721f0eefbf3b4c0838e41 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -78,7 +78,7 @@ namespace xla { int64 scale = 1; int64 linear_index = 0; bool first = true; - for (auto dimension : shape.layout().minor_to_major()) { + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { if (first) { // Avoid two multiplies on the first loop iteration linear_index = multi_index[dimension]; @@ -110,7 +110,7 @@ namespace xla { // Accumulated product D{L(0)} * D{L(1)} * ... int64 divisor = 1; - for (auto dimension : shape.layout().minor_to_major()) { + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { multi_index[dimension] = (linear_index / divisor) % shape.dimensions(dimension); divisor *= shape.dimensions(dimension); @@ -133,21 +133,49 @@ namespace xla { /* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape, int64 dimension) { - const Layout& layout = shape.layout(); - int64 pdim_size = layout.padded_dimensions_size(); + int64 pdim_size = LayoutUtil::PaddedDimensions(shape).size(); int64 stride = 1; DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size()); - for (auto dim : layout.minor_to_major()) { + for (auto dim : LayoutUtil::MinorToMajor(shape)) { if (dim == dimension) { break; } if (pdim_size == 0) { stride *= shape.dimensions(dim); } else { - stride *= layout.padded_dimensions(dim); + stride *= LayoutUtil::PaddedDimension(shape, dim); } } return stride; } +/* static */ bool IndexUtil::IndexInBounds( + const Shape& shape, tensorflow::gtl::ArraySlice index) { + int64 rank = ShapeUtil::Rank(shape); + if (rank != index.size()) { + return false; + } + for (int64 d = 0; d < rank; ++d) { + if (index[d] >= shape.dimensions(d)) { + return false; + } + } + return true; +} + +/* static */ int IndexUtil::CompareIndices( + tensorflow::gtl::ArraySlice lhs, + tensorflow::gtl::ArraySlice rhs) { + int64 rank = lhs.size(); + CHECK_EQ(rhs.size(), rank); + for (int64 dim = 0; dim < rank; ++dim) { + if (lhs[dim] < rhs[dim]) { + return -1; + } else if (lhs[dim] > rhs[dim]) { + return 1; + } + } + return 0; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index c9838966a5b67397eb5fc4afe3ab9d98e82eb2b1..0b9188e8524d6f1367541496dc5a86a250a0d530 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -69,6 +69,18 @@ class IndexUtil { // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 static int64 GetDimensionStride(const Shape& shape, int64 dimension); + // Returns true iff the given multi-index is contained in the bounds for the + // shape. + static bool IndexInBounds(const Shape& shape, + tensorflow::gtl::ArraySlice index); + + // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are + // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is larger, + // then -1 is returned. If rhs is larger, then 1 is returned. Otherwise, 0 is + // returned. + static int CompareIndices(tensorflow::gtl::ArraySlice lhs, + tensorflow::gtl::ArraySlice rhs); + private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); }; diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 5c2cc2a7a99cc51ded3d98c9dd5903e4b3078548..fdc4bbdd8b162b7115788e267c2a53e73c186123 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -57,17 +57,26 @@ void SetDefaultLayoutToContainer( /* static */ Layout LayoutUtil::MakeLayout( tensorflow::gtl::ArraySlice minor_to_major) { Layout layout; + layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { layout.add_minor_to_major(dimension_number); } return layout; } +/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { + Layout layout; + layout.set_format(SPARSE); + layout.set_max_sparse_elements(max_sparse_elements); + return layout; +} + namespace { // Internal helper that creates a default layout for an array of the given rank. Layout CreateDefaultLayoutForRank(int64 rank) { Layout layout; + layout.set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = layout.mutable_minor_to_major(); minor_to_major->Resize(rank, 0); @@ -105,7 +114,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { for (auto& element_shape : *shape->mutable_tuple_shapes()) { SetToDefaultLayout(&element_shape); } + shape->clear_layout(); + } else if (ShapeUtil::IsOpaque(*shape)) { + shape->clear_layout(); } else { + shape->mutable_layout()->set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->Resize(shape->dimensions_size(), 0); @@ -137,8 +150,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } return tensorflow::Status::OK(); - } else if (ShapeUtil::Rank(shape) == 0 && !shape.has_layout()) { - // A scalar without a layout is ok. + } else if (ShapeUtil::IsOpaque(shape)) { + if (shape.has_layout()) { + return InvalidArgument("opaque should not have a layout field"); + } return tensorflow::Status::OK(); } else { // Array shape. @@ -156,46 +171,59 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { + if (ShapeUtil::IsOpaque(shape)) { + return tensorflow::Status::OK(); + } + + if (layout.format() == INVALID_FORMAT) { return InvalidArgument( - "layout minor_to_major field contains %d elements, " - "but shape is rank %lld: {%s}; shape: %s", - layout.minor_to_major_size(), ShapeUtil::Rank(shape), - tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), - shape.ShortDebugString().c_str()); + "Layout does not have a valid format: layout {%s}, shape {%s}", + layout.ShortDebugString().c_str(), shape.ShortDebugString().c_str()); } - std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); - for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { - int64 dim = layout.minor_to_major(i); - if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + if (layout.format() == DENSE) { + if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "layout minor_to_major field has out-of-bounds value: %s", - HumanString(layout).c_str()); + "layout minor_to_major field contains %d elements, " + "but shape is rank %lld: {%s}; shape: %s", + layout.minor_to_major_size(), ShapeUtil::Rank(shape), + tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), + shape.ShortDebugString().c_str()); } - if (dimensions_in_layout[dim]) { - return InvalidArgument( - "layout minor_to_major field has duplicate values: {%s}", - HumanString(layout).c_str()); - } - dimensions_in_layout[dim] = true; - } - if (layout.padded_dimensions_size() > 0) { - if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { - return InvalidArgument( - "layout has %d padded dimensions, but shape is rank %lld", - layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); + std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); + for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { + int64 dim = layout.minor_to_major(i); + if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + return InvalidArgument( + "layout minor_to_major field has out-of-bounds value: %s", + HumanString(layout).c_str()); + } + if (dimensions_in_layout[dim]) { + return InvalidArgument( + "layout minor_to_major field has duplicate values: {%s}", + HumanString(layout).c_str()); + } + dimensions_in_layout[dim] = true; } - for (int i = 0; i < layout.padded_dimensions_size(); ++i) { - if (layout.padded_dimensions(i) < shape.dimensions(i)) { + + if (layout.padded_dimensions_size() > 0) { + if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "for dimension %d, dimension padding (%lld) is smaller than " - "the dimension size (%lld) of the shape", - i, layout.padded_dimensions(i), shape.dimensions(i)); + "layout has %d padded dimensions, but shape is rank %lld", + layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); + } + for (int i = 0; i < layout.padded_dimensions_size(); ++i) { + if (layout.padded_dimensions(i) < shape.dimensions(i)) { + return InvalidArgument( + "for dimension %d, dimension padding (%lld) is smaller than " + "the dimension size (%lld) of the shape", + i, layout.padded_dimensions(i), shape.dimensions(i)); + } } } } + return tensorflow::Status::OK(); } @@ -213,12 +241,23 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::ClearLayout(program_shape->mutable_result()); } +/* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && shape.has_layout() && + IsDense(shape.layout()); +} + +/* static */ bool LayoutUtil::IsDense(const Layout& layout) { + return layout.format() == DENSE; +} + /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) { + CHECK(layout.format() == DENSE); return std::is_sorted(layout.minor_to_major().begin(), layout.minor_to_major().end()); } /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) { + CHECK(layout.format() == DENSE); return std::is_sorted(layout.minor_to_major().begin(), layout.minor_to_major().end(), std::greater()); } @@ -228,6 +267,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape.layout().padded_dimensions_size() == 0) { return false; } + CHECK(IsDenseArray(shape)); CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); for (int64 i = 0; i < shape.dimensions_size(); ++i) { if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { @@ -237,15 +277,46 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return false; } +/* static */ tensorflow::gtl::ArraySlice LayoutUtil::PaddedDimensions( + const Shape& shape) { + CHECK(IsDenseArray(shape)); + return AsInt64Slice(shape.layout().padded_dimensions()); +} + +/* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape, + int64 index) { + CHECK(IsDenseArray(shape)); + return shape.layout().padded_dimensions(index); +} + +/* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) { + CHECK(IsDenseArray(shape)); + return shape.layout().padding_value(); +} + +/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && shape.has_layout() && + IsSparse(shape.layout()); +} + +/* static */ bool LayoutUtil::IsSparse(const Layout& layout) { + return layout.format() == SPARSE; +} + +/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) { + CHECK(IsSparse(layout)); + return layout.max_sparse_elements(); +} + /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape: all subshapes must have a layout. return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), [](const Shape& s) { return HasLayout(s); }); + } else if (ShapeUtil::IsOpaque(shape)) { + return true; } - // A scalar trivially always has a layout. - return (ShapeUtil::Rank(shape) == 0 || - (shape.has_layout() && (shape.layout().minor_to_major_size() > 0))); + return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; } /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) { @@ -261,6 +332,18 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return protobuf_util::ProtobufEquals(lhs, rhs); } +/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( + const Shape& shape) { + CHECK(IsDenseArray(shape)); + return AsInt64Slice(shape.layout().minor_to_major()); +} + +/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( + const Layout& layout) { + CHECK(layout.format() == DENSE); + return AsInt64Slice(layout.minor_to_major()); +} + /* static */ int64 LayoutUtil::Major(const Layout& layout, int64 physical_dimension_number) { CHECK_LE(0, physical_dimension_number); @@ -271,6 +354,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ int64 LayoutUtil::Minor(const Layout& layout, int64 physical_dimension_number) { + CHECK_EQ(layout.format(), DENSE); CHECK_LE(0, physical_dimension_number); CHECK_LT(physical_dimension_number, layout.minor_to_major_size()); return layout.minor_to_major(physical_dimension_number); @@ -287,6 +371,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ string LayoutUtil::HumanString(const Layout& layout) { + if (IsSparse(layout)) { + return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(), + "}"); + } + CHECK(IsDense(layout)); return tensorflow::strings::StrCat( "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}"); } @@ -356,6 +445,7 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, /* static */ bool LayoutUtil::AreDimensionsConsecutive( const Layout& layout, tensorflow::gtl::ArraySlice dims) { + CHECK(IsDense(layout)); std::vector positions_in_layout; for (int64 dim : dims) { positions_in_layout.push_back( @@ -370,4 +460,9 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, return true; } +std::ostream& operator<<(std::ostream& out, const Layout& layout) { + out << LayoutUtil::HumanString(layout); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index bc42e222292933be35e82d1fe50802e8830d16b3..6c54eb2201b66a4a0c5695bceb14bb2367133935 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + // Creates a sparse layout with the given maximum number of elements. (This is + // a convenience function for protobuf construction.) + static Layout MakeSparseLayout(int64 max_sparse_elements); + // Returns default layout for the given shape. static Layout GetDefaultLayoutForShape(const Shape& shape); @@ -71,6 +75,12 @@ class LayoutUtil { // Clears the layout on all Shapes within the given ProgramShape. static void ClearLayout(ProgramShape* program_shape); + // Returns whether the given Shape is an array and has a dense format layout. + static bool IsDenseArray(const Shape& shape); + + // Returns whether the given Layout has a dense format. + static bool IsDense(const Layout& layout); + // Returns whether the layout is monotonic and dim 0 is minor in the layout. // * R0 and R1: this is always trivially true. // * R2+: equivalent to column-major. Dimension 0 is the minor, dimension 1 is @@ -88,6 +98,30 @@ class LayoutUtil { // dimension size). static bool IsPadded(const Shape& shape); + // Returns the padded_dimensions array for the given Shape. Requires that the + // shape is an array and has a dense layout. + static tensorflow::gtl::ArraySlice PaddedDimensions( + const Shape& shape); + + // Returns the given index of the padded_dimensions array for the given Shape. + // Requires that the shape is an array and has a dense layout. + static int64 PaddedDimension(const Shape& shape, int64 index); + + // Returns the padding_value for the given Shape. Requires that the shape is + // an array and has a dense layout. + static PaddingValue GetPaddingValue(const Shape& shape); + + // Returns whether the given Shape is an array (i.e. not a tuple) and has a + // sparse format layout. + static bool IsSparseArray(const Shape& shape); + + // Returns whether the given Layout has a sparse format. + static bool IsSparse(const Layout& layout); + + // Returns the maximum number of elements that can be stored in a sparse + // layout. + static int64 MaxSparseElements(const Layout& layout); + // Returns whether the given shape has a layout. For tuple shapes, true is // returned only if all elements have layouts. static bool HasLayout(const Shape& shape); @@ -98,7 +132,12 @@ class LayoutUtil { // Returns whether lhs and rhs are identical. static bool Equal(const Layout& lhs, const Layout& rhs); - // Major(0) is the most major logical dimension number, major(1) is the + // Returns the minor_to_major array for the given Shape. Requires that the + // shape is an array and has a dense layout. + static tensorflow::gtl::ArraySlice MinorToMajor(const Shape& shape); + static tensorflow::gtl::ArraySlice MinorToMajor(const Layout& layout); + + // Major(0) is the most major logical dimension number, Major(1) is the // second-most-major logical dimension number and so on. // // This can be used to translate physical dimension numbers to logical @@ -160,6 +199,8 @@ class LayoutUtil { TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; +std::ostream& operator<<(std::ostream& out, const Layout& layout); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 331bb9afa94e9e7c97d9c880dbac31c60ac0da18..4fd1d818e3e3b417eee9f6b14bb598bfb9480c6e 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/layout_util.h" + +#include + #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -30,6 +33,14 @@ class LayoutUtilTest : public ::testing::Test { *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); return shape; } + + Shape MakeShapeWithSparseLayout(PrimitiveType element_type, + tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements) { + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); + return shape; + } }; TEST_F(LayoutUtilTest, TupleLayoutComparison) { @@ -81,6 +92,29 @@ TEST_F(LayoutUtilTest, CopyLayoutArray) { EXPECT_FALSE(dst.has_layout()); } +TEST_F(LayoutUtilTest, CopyLayoutSparse) { + Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2); + Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // Should work if destination has no layout. + dst.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // If source is cleared, then destination should be cleared. + src.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_TRUE(dst.has_layout()); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_FALSE(dst.has_layout()); +} + TEST_F(LayoutUtilTest, CopyLayoutTuple) { Shape src = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), @@ -100,6 +134,25 @@ TEST_F(LayoutUtilTest, CopyLayoutTuple) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } +TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithSparseLayout(F32, {2, 3}, 4), + MakeShapeWithSparseLayout(F32, {42, 123}, 4), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); @@ -107,6 +160,13 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } +TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) { + Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6); + Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); + ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -116,6 +176,15 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { ::testing::ContainsRegex("cannot copy layout from shape")); } +TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) { + Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); + Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4); + auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("cannot copy layout from shape")); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { Shape src = ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}), @@ -221,5 +290,16 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } +TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + +TEST_F(LayoutUtilTest, StreamOut) { + std::ostringstream oss; + oss << LayoutUtil::MakeLayout({0, 1, 2}); + EXPECT_EQ(oss.str(), "{0,1,2}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index bfafef0a40f55e13ac94b2d1750df25146081784..e88bffd0ba2dacb837c568023f5da1338fea40f3 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -40,6 +40,10 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { flags->set_xla_cpu_multi_thread_eigen(true); flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); flags->set_xla_eliminate_hlo_implicit_broadcast(true); + + // Set cudnn batchnorm off by default; it does not provide a performance win + // on average. + flags->set_xla_gpu_use_cudnn_batchnorm(false); } // Allocates flag_values and flag_objects; this function must not be called more @@ -96,179 +100,184 @@ void AllocateFlags() { option_proto, reduce_precision_option_value); }; - flag_objects = new std::vector( - {tensorflow::Flag( - "xla_generate_hlo_graph", - flag_values->mutable_xla_generate_hlo_graph(), - "HLO modules matching this regex will be dumped to a .dot file " - "throughout various stages in compilation."), - tensorflow::Flag( - "xla_hlo_graph_addresses", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), - flag_values->xla_hlo_graph_addresses(), - "With xla_generate_hlo_graph, show addresses of HLO ops in " - "graph dump."), - tensorflow::Flag( - "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), - "With xla_generate_hlo_graph, dump the graphs into this path."), - tensorflow::Flag( - "xla_hlo_dump_as_graphdef", - bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), - flag_values->xla_hlo_dump_as_graphdef(), - "Dump HLO graphs as TensorFlow GraphDefs."), - tensorflow::Flag( - "xla_hlo_graph_sharding_color", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), - flag_values->xla_hlo_graph_sharding_color(), - "Assign colors based on sharding assignments when generating the " - "HLO graphs."), - tensorflow::Flag( - "xla_hlo_tfgraph_device_scopes", - bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), - flag_values->xla_hlo_tfgraph_device_scopes(), - "When generating TensorFlow HLO graphs, if the HLO instructions " - "are assigned to a specific device, prefix the name scope with " - "\"devX\" with X being the device ordinal."), - tensorflow::Flag( - "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), - "HLO modules matching this regex will be dumped to LOG(INFO)."), - tensorflow::Flag( - "xla_generate_hlo_text_to", - flag_values->mutable_xla_generate_hlo_text_to(), - "Dump all HLO modules as text into the provided directory path."), - tensorflow::Flag( - "xla_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_enable_fast_math), - flag_values->xla_enable_fast_math(), - "Enable unsafe fast-math optimizations in the compiler; " - "this may produce faster code at the expense of some accuracy."), - tensorflow::Flag( - "xla_llvm_enable_alias_scope_metadata", - bool_setter_for( - &DebugOptions::set_xla_llvm_enable_alias_scope_metadata), - flag_values->xla_llvm_enable_alias_scope_metadata(), - "In LLVM-based backends, enable the emission of " - "!alias.scope metadata in the generated IR."), - tensorflow::Flag( - "xla_llvm_enable_noalias_metadata", - bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), - flag_values->xla_llvm_enable_noalias_metadata(), - "In LLVM-based backends, enable the emission of " - "!noalias metadata in the generated IR."), - tensorflow::Flag( - "xla_llvm_enable_invariant_load_metadata", - bool_setter_for( - &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), - flag_values->xla_llvm_enable_invariant_load_metadata(), - "In LLVM-based backends, enable the emission of " - "!invariant.load metadata in " - "the generated IR."), - tensorflow::Flag( - "xla_llvm_disable_expensive_passes", - bool_setter_for( - &DebugOptions::set_xla_llvm_disable_expensive_passes), - flag_values->xla_llvm_disable_expensive_passes(), - "In LLVM-based backends, disable a custom set of " - "expensive optimization passes."), - tensorflow::Flag( - "xla_backend_optimization_level", - int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), - flag_values->xla_backend_optimization_level(), - "Numerical optimization level for the XLA compiler backend."), - tensorflow::Flag( - "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", - "Comma-separated list of hlo passes to be disabled. These names " - "must exactly match the passes' names; no whitespace around " - "commas."), - tensorflow::Flag( - "xla_embed_ir_in_executable", - bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), - flag_values->xla_embed_ir_in_executable(), - "Embed the compiler IR as a string in the executable."), - tensorflow::Flag( - "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(), - "Dump the compiler IR into this directory as individual files."), - tensorflow::Flag( - "xla_eliminate_hlo_implicit_broadcast", - bool_setter_for( - &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), - flag_values->xla_eliminate_hlo_implicit_broadcast(), - "Eliminate implicit broadcasts when lowering user " - "computations to HLO instructions; use explicit " - "broadcast instead."), - tensorflow::Flag( - "xla_cpu_multi_thread_eigen", - bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), - flag_values->xla_cpu_multi_thread_eigen(), - "When generating calls to Eigen in the CPU backend, " - "use multi-threaded Eigen mode."), - tensorflow::Flag("xla_gpu_cuda_data_dir", - flag_values->mutable_xla_gpu_cuda_data_dir(), - "If non-empty, speficies a local directory containing " - "ptxas and nvvm libdevice files; otherwise we use " - "those from runfile directories."), - tensorflow::Flag("xla_gpu_ftz", - bool_setter_for(&DebugOptions::set_xla_gpu_ftz), - flag_values->xla_gpu_ftz(), - "If true, flush-to-zero semantics are enabled in the " - "code generated for GPUs."), - tensorflow::Flag( - "xla_gpu_disable_multi_streaming", - bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), - flag_values->xla_gpu_disable_multi_streaming(), - "If true, multi-streaming in the GPU backend is disabled."), - tensorflow::Flag( - "xla_dump_hlo_proto_to", - flag_values->mutable_xla_dump_hlo_proto_to(), - "Dump compilation artifacts as proto binary into this directory."), - tensorflow::Flag( - "xla_test_all_output_layouts", - bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), - flag_values->xla_test_all_output_layouts(), - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of output layouts. For example, with " - "a 3D shape, all permutations of the set {0, 1, 2} are " - "tried."), - tensorflow::Flag( - "xla_test_all_input_layouts", - bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), - flag_values->xla_test_all_input_layouts(), - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of *input* layouts. For example, for " - "2 input arguments with 2D shape and 4D shape, the " - "computation will run 2! * 4! times for every possible " - "layouts"), - tensorflow::Flag( - "xla_hlo_profile", - bool_setter_for(&DebugOptions::set_xla_hlo_profile), - flag_values->xla_hlo_profile(), - "Instrument the computation to collect per-HLO cycle counts"), - tensorflow::Flag("xla_dump_computations_to", - flag_values->mutable_xla_dump_computations_to(), - "Dump computations that XLA executes into the provided " - "directory path"), - tensorflow::Flag("xla_dump_executions_to", - flag_values->mutable_xla_dump_executions_to(), - "Dump parameters and results of computations that XLA " - "executes into the provided directory path"), - tensorflow::Flag("xla_backend_extra_options", - setter_for_xla_backend_extra_options, "", - "Extra options to pass to a backend; " - "comma-separated list of 'key=val' strings (=val " - "may be omitted); no whitespace around commas."), - tensorflow::Flag("xla_reduce_precision", setter_for_xla_reduce_precision, - "", - "Directions for adding reduce-precision operations. " - "Format is 'LOCATION=E,M:OPS;NAMES' where LOCATION is " - "the class of locations in which to insert the " - "operations (e.g., 'OP_OUTPUTS'), E and M are the " - "exponent and matissa bit counts respectively, and " - "OPS and NAMES are comma-separated (no spaces) lists " - "of the operation types and names to which to attach " - "the reduce-precision operations. The NAMES string " - "and its preceding ';' may be omitted. This option " - "may be repeated to define multiple sets of added " - "reduce-precision operations.")}); + flag_objects = new std::vector({ + tensorflow::Flag( + "xla_generate_hlo_graph", + flag_values->mutable_xla_generate_hlo_graph(), + "HLO modules matching this regex will be dumped to a .dot file " + "throughout various stages in compilation."), + tensorflow::Flag( + "xla_hlo_graph_addresses", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), + flag_values->xla_hlo_graph_addresses(), + "With xla_generate_hlo_graph, show addresses of HLO ops in " + "graph dump."), + tensorflow::Flag( + "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), + "With xla_generate_hlo_graph, dump the graphs into this path."), + tensorflow::Flag( + "xla_hlo_dump_as_graphdef", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), + flag_values->xla_hlo_dump_as_graphdef(), + "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag( + "xla_hlo_graph_sharding_color", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), + flag_values->xla_hlo_graph_sharding_color(), + "Assign colors based on sharding assignments when generating the " + "HLO graphs."), + tensorflow::Flag( + "xla_hlo_tfgraph_device_scopes", + bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), + flag_values->xla_hlo_tfgraph_device_scopes(), + "When generating TensorFlow HLO graphs, if the HLO instructions " + "are assigned to a specific device, prefix the name scope with " + "\"devX\" with X being the device ordinal."), + tensorflow::Flag( + "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), + "HLO modules matching this regex will be dumped to LOG(INFO)."), + tensorflow::Flag( + "xla_generate_hlo_text_to", + flag_values->mutable_xla_generate_hlo_text_to(), + "Dump all HLO modules as text into the provided directory path."), + tensorflow::Flag( + "xla_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_enable_fast_math), + flag_values->xla_enable_fast_math(), + "Enable unsafe fast-math optimizations in the compiler; " + "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_llvm_enable_alias_scope_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_alias_scope_metadata), + flag_values->xla_llvm_enable_alias_scope_metadata(), + "In LLVM-based backends, enable the emission of " + "!alias.scope metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_noalias_metadata", + bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), + flag_values->xla_llvm_enable_noalias_metadata(), + "In LLVM-based backends, enable the emission of " + "!noalias metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_invariant_load_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), + flag_values->xla_llvm_enable_invariant_load_metadata(), + "In LLVM-based backends, enable the emission of " + "!invariant.load metadata in " + "the generated IR."), + tensorflow::Flag( + "xla_llvm_disable_expensive_passes", + bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes), + flag_values->xla_llvm_disable_expensive_passes(), + "In LLVM-based backends, disable a custom set of " + "expensive optimization passes."), + tensorflow::Flag( + "xla_backend_optimization_level", + int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), + flag_values->xla_backend_optimization_level(), + "Numerical optimization level for the XLA compiler backend."), + tensorflow::Flag( + "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", + "Comma-separated list of hlo passes to be disabled. These names " + "must exactly match the passes' names; no whitespace around " + "commas."), + tensorflow::Flag( + "xla_embed_ir_in_executable", + bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), + flag_values->xla_embed_ir_in_executable(), + "Embed the compiler IR as a string in the executable."), + tensorflow::Flag( + "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(), + "Dump the compiler IR into this directory as individual files."), + tensorflow::Flag( + "xla_eliminate_hlo_implicit_broadcast", + bool_setter_for( + &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), + flag_values->xla_eliminate_hlo_implicit_broadcast(), + "Eliminate implicit broadcasts when lowering user " + "computations to HLO instructions; use explicit " + "broadcast instead."), + tensorflow::Flag( + "xla_cpu_multi_thread_eigen", + bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), + flag_values->xla_cpu_multi_thread_eigen(), + "When generating calls to Eigen in the CPU backend, " + "use multi-threaded Eigen mode."), + tensorflow::Flag("xla_gpu_cuda_data_dir", + flag_values->mutable_xla_gpu_cuda_data_dir(), + "If non-empty, speficies a local directory containing " + "ptxas and nvvm libdevice files; otherwise we use " + "those from runfile directories."), + tensorflow::Flag("xla_gpu_ftz", + bool_setter_for(&DebugOptions::set_xla_gpu_ftz), + flag_values->xla_gpu_ftz(), + "If true, flush-to-zero semantics are enabled in the " + "code generated for GPUs."), + tensorflow::Flag( + "xla_gpu_disable_multi_streaming", + bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), + flag_values->xla_gpu_disable_multi_streaming(), + "If true, multi-streaming in the GPU backend is disabled."), + tensorflow::Flag( + "xla_dump_hlo_proto_to", flag_values->mutable_xla_dump_hlo_proto_to(), + "Dump compilation artifacts as proto binary into this directory."), + tensorflow::Flag( + "xla_test_all_output_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), + flag_values->xla_test_all_output_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of output layouts. For example, with " + "a 3D shape, all permutations of the set {0, 1, 2} are " + "tried."), + tensorflow::Flag( + "xla_test_all_input_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), + flag_values->xla_test_all_input_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of *input* layouts. For example, for " + "2 input arguments with 2D shape and 4D shape, the " + "computation will run 2! * 4! times for every possible " + "layouts"), + tensorflow::Flag( + "xla_hlo_profile", + bool_setter_for(&DebugOptions::set_xla_hlo_profile), + flag_values->xla_hlo_profile(), + "Instrument the computation to collect per-HLO cycle counts"), + tensorflow::Flag("xla_dump_computations_to", + flag_values->mutable_xla_dump_computations_to(), + "Dump computations that XLA executes into the provided " + "directory path"), + tensorflow::Flag("xla_dump_executions_to", + flag_values->mutable_xla_dump_executions_to(), + "Dump parameters and results of computations that XLA " + "executes into the provided directory path"), + tensorflow::Flag("xla_backend_extra_options", + setter_for_xla_backend_extra_options, "", + "Extra options to pass to a backend; " + "comma-separated list of 'key=val' strings (=val " + "may be omitted); no whitespace around commas."), + tensorflow::Flag("xla_reduce_precision", setter_for_xla_reduce_precision, + "", + "Directions for adding reduce-precision operations. " + "Format is 'LOCATION=E,M:OPS;NAMES' where LOCATION is " + "the class of locations in which to insert the " + "operations (e.g., 'OP_OUTPUTS'), E and M are the " + "exponent and matissa bit counts respectively, and " + "OPS and NAMES are comma-separated (no spaces) lists " + "of the operation types and names to which to attach " + "the reduce-precision operations. The NAMES string " + "and its preceding ';' may be omitted. This option " + "may be repeated to define multiple sets of added " + "reduce-precision operations."), + tensorflow::Flag( + "xla_gpu_use_cudnn_batchnorm", + bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm), + flag_values->xla_gpu_use_cudnn_batchnorm(), + "Allows the GPU backend to implement batchnorm HLOs using cudnn, " + "rather than expanding them to a soup of HLOs."), + }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 93d3cd425f0a868b51677058796e9c40c2d3dff8..7f0201e74ab51f8f9906dd045ae7dfb96158f8e9 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,14 +27,20 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" + +using tensorflow::strings::Printf; +using tensorflow::strings::StrCat; + +namespace xla { + namespace { -using tensorflow::int64; constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; @@ -46,9 +52,8 @@ void ConvertEndianShort(char* bytes, int64 size) { std::swap(bytes[i], bytes[i + 1]); } } -} // namespace -namespace xla { +} // namespace std::ostream& operator<<(std::ostream& out, const Literal& literal) { out << literal.ToString(); @@ -64,12 +69,12 @@ Literal::StrideConfig::StrideConfig( if (!dimensions.empty()) { // Selects the shape with the largest minor dimension as the one upon // which to run the tight stride loop. - if (dimensions[source_shape.layout().minor_to_major()[0]] >= - dimensions[dest_shape.layout().minor_to_major()[0]]) { - minor_dimension = source_shape.layout().minor_to_major()[0]; + if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= + dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { + minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); } else { - minor_dimension = dest_shape.layout().minor_to_major()[0]; + minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); source_stride = IndexUtil::GetDimensionStride(source_shape, minor_dimension); } @@ -78,52 +83,134 @@ Literal::StrideConfig::StrideConfig( } } +Literal::Literal(const Shape& shape) + : Literal(shape, /*allocate_arrays=*/true) {} + +Literal::Literal(const Shape& shape, bool allocate_arrays) + : shape_(shape), pieces_(shape), owns_buffers_(true) { + CHECK(LayoutUtil::HasLayout(shape)); + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + const Shape& subshape = piece.subshape(); + if (ShapeUtil::IsArray(subshape)) { + if (allocate_arrays) { + piece.set_buffer(new char[piece.size_bytes()]); + if (LayoutUtil::IsSparseArray(subshape)) { + piece.set_sparse_indices(new SparseIndexArray( + LayoutUtil::MaxSparseElements(subshape.layout()), + ShapeUtil::Rank(subshape))); + } + } else { + piece.set_buffer(nullptr); + } + } + } +} + +Literal::~Literal() { DeallocateBuffers(); } + +void Literal::DeallocateBuffers() { + if (owns_buffers_) { + for (auto& pair : pieces_) { + Piece& piece = pair.second; + if (piece.buffer() != nullptr) { + delete[] piece.buffer(); + delete piece.sparse_indices(); + } + } + } +} + +Literal::Literal(Literal&& other) { + shape_ = std::move(other.shape_); + pieces_ = std::move(other.pieces_); + // We need to iterate through the pieces to set the subshape pointer + // properly. It must refer to subshapes within shape_. + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + } + owns_buffers_ = other.owns_buffers_; + + other.shape_ = ShapeUtil::MakeNil(); + other.pieces_ = ShapeTree(other.shape_); + other.piece({}).set_subshape(&other.shape_); +} + +Literal& Literal::operator=(Literal&& other) { + DeallocateBuffers(); + shape_ = std::move(other.shape_); + pieces_ = std::move(other.pieces_); + // We need to iterate through the pieces to set the subshape pointer + // properly. It must refer to subshapes within shape_. + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + } + owns_buffers_ = other.owns_buffers_; + + other.shape_ = ShapeUtil::MakeNil(); + other.pieces_ = ShapeTree(other.shape_); + other.piece({}).set_subshape(&other.shape_); + return *this; +} + std::unique_ptr Literal::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique(); - *literal->mutable_shape() = shape; - if (ShapeUtil::IsTuple(shape)) { - int64 num_elements = ShapeUtil::TupleElementCount(shape); - literal->tuple_literals_.resize(num_elements); - for (int i = 0; i < num_elements; ++i) { - std::unique_ptr elem = - CreateFromShape(ShapeUtil::GetTupleElementShape(shape, i)); - literal->tuple_literals_[i] = std::move(*elem); + auto literal = MakeUnique(shape); + for (auto& pair : literal->pieces_) { + Piece& piece = pair.second; + if (ShapeUtil::IsArray(piece.subshape())) { + memset(piece.untyped_data(), 0, piece.size_bytes()); } - } else { - literal->Reserve(ShapeUtil::ElementsIn(literal->shape())); } return literal; } +const SparseIndexArray* Literal::sparse_indices( + const ShapeIndex& shape_index) const { + return piece(shape_index).sparse_indices(); +} + +SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { + return piece(shape_index).sparse_indices(); +} + /* static */ std::unique_ptr Literal::CreateFromDimensions( PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } -template -Status Literal::CopyRange(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - const Shape& src_shape = src_literal.shape(); - const Shape& dest_shape = shape(); - tensorflow::gtl::ArraySlice src_data = src_literal.GetArraySlice(); - tensorflow::gtl::MutableArraySlice dest_data = GetMutableArraySlice(); - - TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size()); - TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size()); +template +Status Literal::CopySliceFromInternal( + const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); + TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + + auto linear_index = [](const Shape& shape, + tensorflow::gtl::ArraySlice multi_index) { + return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); + }; - if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) { + if (ShapeUtil::Rank(src_literal.shape()) == 0 || + ShapeUtil::Rank(shape()) == 0) { // If any of the two shapes are scalars, we can just call the StridedCopy() // directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); - StridedCopy(dest_data, LinearIndex(dest_base), 0, src_data, - src_literal.LinearIndex(src_base), 0, 1); - } else if (!ShapeUtil::HasZeroElements(dest_shape) && - !ShapeUtil::HasZeroElements(src_shape)) { - // Perform copy if neither src literal nor dest literal has dimensions with - // zero element, otherwise it's a no-op. + StridedCopy(data(), linear_index(shape(), dest_base), 0, + src_literal.data(), + linear_index(src_literal.shape(), src_base), 0, 1); + } else if (!ShapeUtil::HasZeroElements(shape()) && + !ShapeUtil::HasZeroElements(src_literal.shape())) { + // Perform copy if neither src nor dest has dimensions with zero element, + // otherwise it's a no-op. TF_RET_CHECK(src_base.size() == dest_base.size()); TF_RET_CHECK(src_base.size() == copy_size.size()); @@ -133,7 +220,8 @@ Status Literal::CopyRange(const Literal& src_literal, // proper stride size at the matching dimension. DimensionVector src_indexes(src_base.size(), 0); DimensionVector dest_indexes(dest_base.size(), 0); - StrideConfig stride_config(src_shape, dest_shape, copy_size); + Literal::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); auto copy_proc = [&](const std::vector& indexes) { // Map from multi-dimensional index, to source index. @@ -143,89 +231,295 @@ Status Literal::CopyRange(const Literal& src_literal, std::transform(indexes.begin(), indexes.end(), dest_base.begin(), dest_indexes.begin(), std::plus()); - int64 src_index = src_literal.LinearIndex(src_indexes); - int64 dest_index = LinearIndex(dest_indexes); + int64 src_index = linear_index(src_literal.shape(), src_indexes); + int64 dest_index = linear_index(shape(), dest_indexes); - StridedCopy(dest_data, dest_index, stride_config.dest_stride, src_data, - src_index, stride_config.source_stride, - stride_config.minor_loop_size); + StridedCopy(data(), dest_index, stride_config.dest_stride, + src_literal.data(), src_index, + stride_config.source_stride, stride_config.minor_loop_size); return true; }; - ShapeUtil::ForEachIndex(src_shape, stride_config.base, + ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base, stride_config.dimensions, stride_config.step, copy_proc); } return Status::OK(); } -Status Literal::Copy(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +std::vector Literal::DecomposeTuple() { + CHECK(ShapeUtil::IsTuple(shape())); + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), + /*allocate_arrays=*/false)); + Literal& element = elements.back(); + for (auto& pair : element.pieces_) { + const ShapeIndex& index = pair.first; + Piece& dest_piece = pair.second; + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece.set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + } + } + // Set this literal to be nil-shaped. + *this = Literal(); + return elements; +} + +/* static */ Literal Literal::MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements) { + std::vector element_shapes; + for (const Literal& element : elements) { + element_shapes.push_back(element.shape()); + } + Literal literal(ShapeUtil::MakeTupleShape(element_shapes), + /*allocate_arrays=*/false); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK( + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); + } + return literal; +} + +namespace { + +// Copies the elements in 'src' to 'dest'. The shape and layout of the data in +// the array slices are indicated by dest_shape and src_shape respectively. +template +void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, + tensorflow::gtl::ArraySlice src, + const Shape& dest_shape, const Shape& src_shape) { + CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); + if (ShapeUtil::HasZeroElements(dest_shape)) { + return; + } + std::vector index(ShapeUtil::Rank(dest_shape)); + do { + dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = + src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; + } while (IndexUtil::BumpIndices(dest_shape, &index)); +} + +} // namespace + +Status Literal::Piece::CopyFrom(const Literal::Piece& src) { + if (ShapeUtil::Equal(subshape(), src.subshape())) { + // If the layouts are equal it's faster just to memcpy. + memcpy(buffer(), src.buffer(), src.size_bytes()); + } else { + TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); + std::vector origin(ShapeUtil::Rank(subshape()), 0); + switch (subshape().element_type()) { +#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ + case (XLA_T): \ + CopyElementsBetween(data(), src.data(), \ + subshape(), src.subshape()); \ + break; + COPY_ELEMENTS(U8, uint8); + COPY_ELEMENTS(U16, uint16); + COPY_ELEMENTS(U32, uint32); + COPY_ELEMENTS(U64, uint64); + COPY_ELEMENTS(S8, int8); + COPY_ELEMENTS(S16, int16); + COPY_ELEMENTS(S32, int32); + COPY_ELEMENTS(S64, int64); + COPY_ELEMENTS(F16, half); + COPY_ELEMENTS(BF16, bfloat16); + COPY_ELEMENTS(F32, float); + COPY_ELEMENTS(F64, double); + COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(PRED, bool); +#undef COPY_ELEMENTS + default: + return Unimplemented( + "Unhandled primitive type %s", + PrimitiveType_Name(subshape().element_type()).c_str()); + } + } + return Status::OK(); +} + +Status Literal::CopyFrom(const Literal& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + const Shape& src_subshape = + ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index); + if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { + return InvalidArgument( + "Destination subshape incompatible with source subshape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_subshape).c_str()); + } + + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; + } + + // Determine if this index is in the part of this literal that we want to + // copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + continue; + } + + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + + TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index))); + } + return Status::OK(); +} + +Status Literal::MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { + return InvalidArgument( + "Destination subshape not equal to source shape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_literal.shape()).c_str()); + } + + if (!(owns_buffers_ && src_literal.owns_buffers_)) { + return InvalidArgument( + "Source and destination literals must both own their buffers (ie, not " + "be views)"); + } + + for (auto& pair : src_literal.pieces_) { + const ShapeIndex& src_index = pair.first; + Piece& src_piece = pair.second; + if (!ShapeUtil::IsArray(src_piece.subshape())) { + continue; + } + + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + } + + src_literal.shape_ = ShapeUtil::MakeNil(); + src_literal.pieces_ = ShapeTree(src_literal.shape_); + src_literal.piece({}).set_subshape(&src_literal.shape_); + return Status::OK(); +} + +Status Literal::CopySliceFrom(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); + TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) + << ShapeUtil::HumanString(src_literal.shape()); TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); - switch (src_literal.shape().element_type()) { + + switch (shape().element_type()) { case U8: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case U16: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case U32: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case U64: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case S8: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case S16: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case S32: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case S64: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case F16: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case BF16: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case F32: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case F64: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case C64: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case PRED: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); default: break; } - return Unimplemented("Unhandled primitive type %d", - src_literal.shape().element_type()); + return Unimplemented("Unhandled primitive type %d", shape().element_type()); } /* static */ Literal Literal::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case U32: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case U64: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case S8: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case S32: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case S64: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case F16: - return *Literal::CreateR0(static_cast(0.0f)); + return std::move(*Literal::CreateR0(static_cast(0.0f))); case BF16: - return *Literal::CreateR0(static_cast(0.0f)); + return std::move( + *Literal::CreateR0(static_cast(0.0f))); case F32: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case F64: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case C64: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case PRED: - return *Literal::CreateR0(false); + return std::move(*Literal::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -241,30 +535,33 @@ Status Literal::Copy(const Literal& src_literal, /* static */ Literal Literal::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case U32: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case U64: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case S8: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case S32: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case S64: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); + case F16: + return std::move(*Literal::CreateR0(static_cast(1.0f))); + case BF16: + return std::move( + *Literal::CreateR0(static_cast(1.0f))); case F32: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case F64: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case C64: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case PRED: - return *Literal::CreateR0(true); + return std::move(*Literal::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - return *Literal::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -277,35 +574,42 @@ Status Literal::Copy(const Literal& src_literal, /* static */ Literal Literal::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case U32: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case U64: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case S8: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case S32: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case S64: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case F32: - return *Literal::CreateR0(-std::numeric_limits::infinity()); + return std::move( + *Literal::CreateR0(-std::numeric_limits::infinity())); case F64: - return *Literal::CreateR0( - -std::numeric_limits::infinity()); + return std::move( + *Literal::CreateR0(-std::numeric_limits::infinity())); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return *Literal::CreateR0(false); + return std::move(*Literal::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *Literal::CreateR0( - static_cast(-std::numeric_limits::infinity())); + return std::move(*Literal::CreateR0( + static_cast(-std::numeric_limits::infinity()))); case BF16: - return *Literal::CreateR0( - static_cast(-std::numeric_limits::infinity())); + return std::move(*Literal::CreateR0( + static_cast(-std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -318,33 +622,40 @@ Status Literal::Copy(const Literal& src_literal, /* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case U32: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case U64: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case S8: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case S32: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case S64: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case F32: - return *Literal::CreateR0(std::numeric_limits::infinity()); + return std::move( + *Literal::CreateR0(std::numeric_limits::infinity())); case F64: - return *Literal::CreateR0( - std::numeric_limits::infinity()); + return std::move( + *Literal::CreateR0(std::numeric_limits::infinity())); case PRED: - return *Literal::CreateR0(true); + return std::move(*Literal::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *Literal::CreateR0( - static_cast(std::numeric_limits::infinity())); + return std::move(*Literal::CreateR0( + static_cast(std::numeric_limits::infinity()))); case BF16: - return *Literal::CreateR0( - static_cast(std::numeric_limits::infinity())); + return std::move(*Literal::CreateR0( + static_cast(std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -356,17 +667,29 @@ Status Literal::Copy(const Literal& src_literal, /* static */ std::unique_ptr Literal::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = MakeUnique(); + auto literal = MakeUnique( + ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); literal->PopulateR1(values); return literal; } +void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(element_count(), values.bits()); + CHECK_EQ(shape().element_type(), PRED); + for (int64 i = 0; i < static_cast(values.bits()); ++i) { + Set({i}, values.get(i)); + } +} + /* static */ std::unique_ptr Literal::CreateR1U8( tensorflow::StringPiece value) { - auto literal = MakeUnique(); - *literal->mutable_shape() = - ShapeUtil::MakeShape(U8, {static_cast(value.size())}); - literal->set_u8s(tensorflow::StringPiece(value.ToString())); + auto literal = MakeUnique( + ShapeUtil::MakeShape(U8, {static_cast(value.size())})); + for (int i = 0; i < value.size(); ++i) { + literal->Set({i}, value[i]); + } return literal; } @@ -380,46 +703,50 @@ Status Literal::Copy(const Literal& src_literal, std::unique_ptr Literal::Relayout( const Layout& new_layout, const ShapeIndex& shape_index) const { - std::unique_ptr outer_result = CloneToUnique(); - - const Literal* copy_from = this; - Literal* copy_to = outer_result.get(); - for (int64 i = 0; i < shape_index.size(); i++) { - *ShapeUtil::GetMutableSubshape(copy_to->mutable_shape(), {shape_index, i}) - ->mutable_layout() = new_layout; - copy_from = ©_from->tuple_literals_[shape_index[i]]; - copy_to = ©_to->tuple_literals_[shape_index[i]]; - } - - DimensionVector base(ShapeUtil::Rank(copy_from->shape()), 0); - DimensionVector copy_size(copy_from->shape().dimensions().begin(), - copy_from->shape().dimensions().end()); + // Create new shape with 'new_layout' set at the given shape index. + Shape new_shape = shape(); + Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); + TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); + *subshape->mutable_layout() = new_layout; + auto result = MakeUnique(new_shape); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} - CHECK(ShapeUtil::IsArray(copy_from->shape())); - CHECK(ShapeUtil::IsArray(copy_to->shape())); - *copy_to->mutable_shape()->mutable_layout() = new_layout; - TF_CHECK_OK(copy_to->Copy(*copy_from, base, base, copy_size)); - return outer_result; +std::unique_ptr Literal::Relayout( + const Shape& shape_with_layout) const { + CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) + << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) + << " not compatible with literal shape " + << ShapeUtil::HumanString(shape()); + std::unique_ptr result = CreateFromShape(shape_with_layout); + ShapeUtil::ForEachSubshape( + result->shape(), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + TF_CHECK_OK(result->CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); + } + }); + return result; } StatusOr> Literal::Reshape( tensorflow::gtl::ArraySlice dimensions) const { - if (ShapeUtil::IsTuple(shape())) { + if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } std::unique_ptr output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - std::vector minor_to_major(ShapeUtil::Rank(shape())); - std::iota(minor_to_major.rbegin(), minor_to_major.rend(), - static_cast(0)); - output = Relayout(LayoutUtil::MakeLayout(minor_to_major)); + output = + Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); } else { output = CloneToUnique(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. - *output->mutable_shape() = - ShapeUtil::MakeShape(shape().element_type(), dimensions); + output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -435,7 +762,7 @@ StatusOr> Literal::Reshape( std::unique_ptr Literal::Transpose( tensorflow::gtl::ArraySlice permutation) const { - CHECK(!ShapeUtil::IsTuple(shape())) << "Tuple is not supported for transpose"; + CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; // To transpose the array, we just permute the dimensions and layout, and @@ -458,23 +785,24 @@ std::unique_ptr Literal::Transpose( // dimension has within the transposed array, a layout is affine if // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major // vector of the affine layout. + CHECK(LayoutUtil::IsDenseArray(permuted_shape)); Layout* layout = permuted_shape.mutable_layout(); layout->clear_minor_to_major(); - for (auto index : shape().layout().minor_to_major()) { + for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } std::unique_ptr new_literal = CreateFromShape(permuted_shape); DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->MutableInternalData(), InternalData(), - ShapeUtil::ByteSizeOf(shape())); + std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(), + root_piece().size_bytes()); return new_literal; } std::unique_ptr Literal::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { - CHECK(!ShapeUtil::IsTuple(shape())) << "tuple is not supported for reshape"; + CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { @@ -484,13 +812,11 @@ std::unique_ptr Literal::Slice( CHECK_GT(dimension, 0); result_dimensions.push_back(dimension); } - const auto result_shape = ShapeUtil::MakeShapeWithLayout( - shape().element_type(), result_dimensions, - AsInt64Slice(shape().layout().minor_to_major())); + const auto result_shape = + ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, + LayoutUtil::MinorToMajor(shape())); - auto result_literal = MakeUnique(); - *result_literal->mutable_shape() = result_shape; - result_literal->Reserve(ShapeUtil::ElementsIn(result_shape)); + auto result_literal = MakeUnique(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { @@ -530,48 +856,116 @@ std::unique_ptr Literal::Slice( } } +Literal Literal::Clone() const { + Literal result(shape()); + TF_CHECK_OK(result.CopyFrom(*this)); + return result; +} + std::unique_ptr Literal::CloneToUnique() const { - auto unique = MakeUnique(); - *unique = *this; - return unique; + auto result = MakeUnique(shape()); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; } -string Literal::GetAsString( - tensorflow::gtl::ArraySlice multi_index) const { - switch (shape().element_type()) { +string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsDenseArray(subshape)); + switch (subshape.element_type()) { case PRED: - return Get(multi_index) ? "true" : "false"; - case U8: - return tensorflow::strings::StrCat(Get(multi_index)); + return Get(multi_index, shape_index) ? "true" : "false"; + case S8: + return StrCat(Get(multi_index, shape_index)); + case S16: + return StrCat(Get(multi_index, shape_index)); case S32: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); case S64: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); + case U8: + return StrCat(Get(multi_index, shape_index)); + case U16: + return StrCat(Get(multi_index, shape_index)); case U32: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); case U64: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); + case F16: + return StrCat(Get(multi_index, shape_index)); case F32: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); + case BF16: + return StrCat( + static_cast(Get(multi_index, shape_index))); case F64: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); case C64: { - complex64 c = Get(multi_index); - return tensorflow::strings::StrCat("(", c.real(), ", ", c.imag(), ")"); + complex64 c = Get(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); } + default: + LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); + } +} + +string Literal::GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsSparseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return GetSparseElement(sparse_element_number, shape_index) + ? "true" + : "false"; + case S8: + return StrCat(GetSparseElement(sparse_element_number, shape_index)); + case S16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U8: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); case F16: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(GetSparseElement(sparse_element_number, shape_index)); + case F32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); case BF16: - return tensorflow::strings::StrCat( - static_cast(Get(multi_index))); + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); + case F64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case C64: { + complex64 c = + GetSparseElement(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: - return tensorflow::strings::StrCat( - "[", PrimitiveType_Name(shape().element_type()), "]"); + LOG(FATAL) << "Invalid element type for sparse arrays: " + << PrimitiveType_Name(subshape.element_type()); } } StatusOr Literal::GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: return Get(multi_index); @@ -592,13 +986,83 @@ StatusOr Literal::GetIntegralAsS64( } } -int64 Literal::LinearIndex( - tensorflow::gtl::ArraySlice multi_index) const { - return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); +tensorflow::gtl::ArraySlice Literal::GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Piece& p = piece(shape_index); + CHECK_GE(sparse_element_number, 0); + CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); + return p.sparse_indices()->At(sparse_element_number); } -string Literal::ToString(bool print_layout) const { - std::vector pieces; +void Literal::SortSparseElements(const ShapeIndex& shape_index) { + piece(shape_index).SortSparseElements(); +} + +void Literal::Piece::SortSparseElements() { + switch (subshape().element_type()) { + case PRED: + SortSparseElementsInternal(); + break; + case S8: + SortSparseElementsInternal(); + break; + case U8: + SortSparseElementsInternal(); + break; + case S16: + SortSparseElementsInternal(); + break; + case U16: + SortSparseElementsInternal(); + break; + case S32: + SortSparseElementsInternal(); + break; + case U32: + SortSparseElementsInternal(); + break; + case S64: + SortSparseElementsInternal(); + break; + case U64: + SortSparseElementsInternal(); + break; + case F32: + SortSparseElementsInternal(); + break; + case F64: + SortSparseElementsInternal(); + break; + case C64: + SortSparseElementsInternal(); + break; + case F16: + SortSparseElementsInternal(); + break; + case BF16: + SortSparseElementsInternal(); + break; + default: + LOG(FATAL) << "Element type not valid for sparse array: " + << PrimitiveType_Name(subshape().element_type()); + } +} + +template +void Literal::Piece::SortSparseElementsInternal() { + CHECK(LayoutUtil::IsSparseArray(subshape())); + int64 num_elements = sparse_indices()->index_count(); + auto values = data(); + CHECK_LE(num_elements, values.size()); + sparse_indices()->SortWithValues( + tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); +} + +namespace { + +void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); auto shape_to_string = [print_layout](const Shape& shape) { if (print_layout) { @@ -608,272 +1072,186 @@ string Literal::ToString(bool print_layout) const { } }; + // TODO(b/32894291): refactor this code to reduce code duplication. + if (ShapeUtil::IsTuple(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" (\n"); + std::vector tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + } + pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back("\n)"); + return; + } + + if (LayoutUtil::IsSparseArray(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); + } + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back( + tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); + } + pieces->push_back(literal.GetSparseElementAsString(i)); + } + pieces->push_back("}"); + return; + } + + CHECK(LayoutUtil::IsDenseArray(subshape)); + auto element_to_string = - [this](tensorflow::gtl::ArraySlice indices) -> string { - PrimitiveType element_type = shape().element_type(); + [&](tensorflow::gtl::ArraySlice indices) -> string { + PrimitiveType element_type = subshape.element_type(); if (element_type == PRED) { // We display predicates in a densely packed form. - return Get(indices) ? "1" : "0"; + return literal.Get(indices, shape_index) ? "1" : "0"; } return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - GetAsString(indices); + literal.GetAsString(indices, shape_index); }; - // TODO(b/32894291): refactor this code to reduce code duplication. - if (ShapeUtil::IsTuple(shape())) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" (\n"); - pieces.push_back(tensorflow::str_util::Join( - tuple_literals(), ",\n", [](string* out, const Literal& element) { - tensorflow::strings::StrAppend(out, element.ToString()); - })); - pieces.push_back("\n)"); - } else if (ShapeUtil::Rank(shape()) == 0) { - pieces.push_back(GetAsString({})); - } else if (ShapeUtil::Rank(shape()) == 1) { - pieces.push_back("{"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(element_to_string({i0})); + if (ShapeUtil::Rank(subshape) == 0) { + pieces->push_back(literal.GetAsString({}, shape_index)); + } else if (ShapeUtil::Rank(subshape) == 1) { + pieces->push_back("{"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(element_to_string({i0})); } - pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape()) == 2) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(" { "); - for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { - pieces.push_back(element_to_string({i0, i1})); + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 2) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(" { "); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(element_to_string({i0, i1})); } - pieces.push_back(" "); - pieces.push_back(i0 == shape().dimensions(0) - 1 ? "}\n" : "},\n"); + pieces->push_back(" "); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n"); } - pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape()) == 3) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { - pieces.push_back(i1 > 0 ? ",\n { " : " { "); - for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { - pieces.push_back(element_to_string({i0, i1, i2})); + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 3) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(i0 > 0 ? ",\n{" : "{"); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(i1 > 0 ? ",\n { " : " { "); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(element_to_string({i0, i1, i2})); } - pieces.push_back(" }"); + pieces->push_back(" }"); } - pieces.push_back(" }"); + pieces->push_back(" }"); } - pieces.push_back("\n}"); - } else if (ShapeUtil::Rank(shape()) == 4) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { - pieces.push_back( - tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { - pieces.push_back(" {"); - for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { - pieces.push_back(element_to_string({i0, i1, i2, i3})); + pieces->push_back("\n}"); + } else if (ShapeUtil::Rank(subshape) == 4) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(" {"); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(element_to_string({i0, i1, i2, i3})); } - pieces.push_back(i2 == shape().dimensions(2) - 1 ? "}\n" : "},\n"); + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); } - pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n" - : " },\n"); + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n"); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); } - pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape()) == 5) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { - pieces.push_back( - tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { - pieces.push_back( - tensorflow::strings::Printf(" { /*i2=%lld*/\n", i2)); - for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { - pieces.push_back(" {"); - for (int64 i4 = 0; i4 < shape().dimensions(4); ++i4) { - pieces.push_back(element_to_string({i0, i1, i2, i3, i4})); + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 5) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(" {"); + for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { + pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); } - pieces.push_back(i3 == shape().dimensions(3) - 1 ? "}\n" : "},\n"); + pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" + : "},\n"); } - pieces.push_back(i2 == shape().dimensions(2) - 1 ? " }\n" - : " },\n"); + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n" - : " },\n"); + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n"); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); } - pieces.push_back("}"); + pieces->push_back("}"); } else { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {...}"); + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {"); + literal.EachCellAsString( + [&](tensorflow::gtl::ArraySlice indices, const string& value) { + pieces->push_back(" "); + pieces->push_back(value); + }); + pieces->push_back("}"); } +} + +} // namespace + +int64 Literal::sparse_element_count() const { + CHECK(LayoutUtil::IsSparseArray(shape())); + return sparse_indices()->index_count(); +} +string Literal::ToString(bool print_layout) const { + std::vector pieces; + ToStringHelper(*this, {}, print_layout, &pieces); return tensorflow::str_util::Join(pieces, ""); } /* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { - auto literal = MakeUnique(); - std::vector shape; - for (const Literal* tuple_element : elements) { - *literal->add_tuple_literals() = *tuple_element; - shape.push_back(tuple_element->shape()); + std::vector element_shapes; + for (const Literal* element : elements) { + element_shapes.push_back(element->shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } - *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape); return literal; } /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { - auto literal = MakeUnique(); - std::vector shape; - for (auto& tuple_element : elements) { - shape.push_back(tuple_element->shape()); - *literal->add_tuple_literals() = std::move(*tuple_element); - } - *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape); - return literal; -} - -const void* Literal::InternalData() const { - return const_cast( - const_cast(this)->MutableInternalData()); -} - -void* Literal::MutableInternalData() { - // NOTE: We access the vectors directly to avoid the const reference - // created by the accessor functions. - switch (shape().element_type()) { - case PRED: - case U8: - return reinterpret_cast(u8s_.data()); - case S32: - return reinterpret_cast(s32s_.data()); - case S64: - return reinterpret_cast(s64s_.data()); - case U32: - return reinterpret_cast(u32s_.data()); - case U64: - return reinterpret_cast(u64s_.data()); - case F32: - return reinterpret_cast(f32s_.data()); - case F64: - return reinterpret_cast(f64s_.data()); - case C64: - return reinterpret_cast(c64s_.data()); - case F16: - return reinterpret_cast(f16s_.data()); - case BF16: - return reinterpret_cast(bf16s_.data()); - default: - LOG(FATAL) << "primitive type not supported in literals: " - << PrimitiveType_Name(shape().element_type()); - } -} - -void Literal::Reserve(int64 num_elements) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - switch (shape().element_type()) { - case PRED: - Resize(num_elements, false); - break; - case S8: - Resize(num_elements, 0); - break; - case U8: - Resize(num_elements, 0); - break; - case S32: - Resize(num_elements, 0); - break; - case S64: - Resize(num_elements, 0); - break; - case U32: - Resize(num_elements, 0); - break; - case U64: - Resize(num_elements, 0); - break; - case F32: - Resize(num_elements, 0); - break; - case F64: - Resize(num_elements, 0); - break; - case C64: - Resize(num_elements, 0); - break; - case F16: - Resize(num_elements, static_cast(0.0f)); - break; - case BF16: - Resize(num_elements, static_cast(0.0f)); - break; - default: - LOG(FATAL) << "primitive type not supported in literals: " - << PrimitiveType_Name(shape().element_type()); - } -} - -tensorflow::Status Literal::ValidateLiteral() const { - TF_CHECK_OK(ShapeUtil::ValidateShape(shape())); - int64 expected = ShapeUtil::ElementsIn(shape()); - int64 actual = -1; - switch (shape().element_type()) { - case PRED: - case U8: - actual = u8s_size(); - break; - case S32: - actual = s32s_size(); - break; - case U32: - actual = u32s_size(); - break; - case S64: - actual = s64s_size(); - break; - case U64: - actual = u64s_size(); - break; - case F32: - actual = f32s_size(); - break; - case F64: - actual = f64s_size(); - break; - case C64: - actual = c64s_size(); - break; - case F16: - actual = f16s().size() / sizeof(half); - break; - case BF16: - actual = bf16s().size(); - break; - default: - return tensorflow::errors::Unimplemented( - "unhandled element type for literal validation: " + - PrimitiveType_Name(shape().element_type())); - } - - if (expected != actual) { - return tensorflow::errors::InvalidArgument(tensorflow::strings::Printf( - "literal has bad number of elements for its shape %s: want %lld " - "got %lld", - ShapeUtil::HumanString(shape()).c_str(), expected, actual)); + std::vector element_ptrs; + for (const auto& element : elements) { + element_ptrs.push_back(element.get()); } - - return tensorflow::Status::OK(); + return MakeTuple(element_ptrs); } void Literal::EachCellAsString( @@ -892,17 +1270,13 @@ void Literal::EachCellAsString( namespace { template std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { - auto result_literal = MakeUnique(); - Shape* result_shape = result_literal->mutable_shape(); - *result_shape = src_literal.shape(); - result_shape->set_element_type( - primitive_util::NativeToPrimitiveType()); - result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); - tensorflow::gtl::ArraySlice src_data = - src_literal.GetArraySlice(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->GetMutableArraySlice(); - int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( + src_literal.shape(), + primitive_util::NativeToPrimitiveType())); + auto src_data = src_literal.data(); + auto dest_data = result_literal->template data(); + int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = static_cast(src_data[i]); @@ -912,18 +1286,16 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { template std::unique_ptr ConvertToC64(const Literal& src_literal) { - auto result_literal = MakeUnique(); - Shape* result_shape = result_literal->mutable_shape(); - *result_shape = src_literal.shape(); - result_shape->set_element_type(C64); - result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique( + ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; tensorflow::gtl::ArraySlice src_data = - src_literal.GetArraySlice(); + src_literal.data(); tensorflow::gtl::MutableArraySlice dest_data = - result_literal->GetMutableArraySlice(); - int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); + result_literal->data(); + int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast(src_data[i]), 0); } @@ -968,10 +1340,12 @@ StatusOr> ConvertIfDestTypeMatches( PrimitiveType_Name(primitive_dest_type).c_str()); } } + } // namespace StatusOr> Literal::Convert( PrimitiveType primitive_dest_type) const { + TF_RET_CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ case (type): \ @@ -996,356 +1370,192 @@ StatusOr> Literal::Convert( } } -namespace { - -// Helper function which compares whether the elements of literal1 are equal to -// the elements of literal2. Recursively iterates through the entire -// multidimensional index space and compares the literal elements -// one-by-one. literal1 and literal2 must be compatible (same dimensions and -// type). template -bool EqualElements(const Literal& literal1, const Literal& literal2, - int dimension, std::vector* multi_index) { - if (dimension == ShapeUtil::Rank(literal1.shape())) { - return (literal1.Get(*multi_index) == - literal2.Get(*multi_index)); - } - for (int64 i = 0; i < literal1.shape().dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - if (!EqualElements(literal1, literal2, dimension + 1, - multi_index)) { +bool Literal::Piece::EqualElementsInternal( + const Literal::Piece& other, std::vector* multi_index) const { + if (multi_index->size() == ShapeUtil::Rank(subshape())) { + return (Get(*multi_index) == other.Get(*multi_index)); + } + for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { + multi_index->push_back(i); + if (!EqualElementsInternal(other, multi_index)) { return false; } + multi_index->pop_back(); } return true; } -} // namespace +bool Literal::Piece::EqualElements(const Literal::Piece& other) const { + DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + + std::vector multi_index; + switch (subshape().element_type()) { + case PRED: + return EqualElementsInternal(other, &multi_index); + case U8: + return EqualElementsInternal(other, &multi_index); + case S32: + return EqualElementsInternal(other, &multi_index); + case S64: + return EqualElementsInternal(other, &multi_index); + case U32: + return EqualElementsInternal(other, &multi_index); + case U64: + return EqualElementsInternal(other, &multi_index); + case F32: + return EqualElementsInternal(other, &multi_index); + case F64: + return EqualElementsInternal(other, &multi_index); + case F16: + return EqualElementsInternal(other, &multi_index); + case BF16: + return EqualElementsInternal(other, &multi_index); + case C64: + return EqualElementsInternal(other, &multi_index); + default: + LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type " + << PrimitiveType_Name(subshape().element_type()); + } +} bool Literal::operator==(const Literal& other) const { if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } - if (ShapeUtil::IsTuple(shape())) { - // Because the shapes are compatible, they must have the same number of - // tuple elements. - CHECK_EQ(tuple_literals_size(), other.tuple_literals_size()); - for (int i = 0; i < tuple_literals_size(); ++i) { - if (tuple_literals(i) != other.tuple_literals(i)) { - return false; - } + for (const auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + const Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; } - return true; - } else { - std::vector multi_index(ShapeUtil::Rank(shape()), 0); - switch (shape().element_type()) { - case PRED: - return EqualElements(*this, other, 0, &multi_index); - case U8: - return EqualElements(*this, other, 0, &multi_index); - case S32: - return EqualElements(*this, other, 0, &multi_index); - case S64: - return EqualElements(*this, other, 0, &multi_index); - case U32: - return EqualElements(*this, other, 0, &multi_index); - case U64: - return EqualElements(*this, other, 0, &multi_index); - case F32: - return EqualElements(*this, other, 0, &multi_index); - case F64: - return EqualElements(*this, other, 0, &multi_index); - case F16: - return EqualElements(*this, other, 0, &multi_index); - case BF16: - return EqualElements(*this, other, 0, &multi_index); - case C64: - return EqualElements(*this, other, 0, &multi_index); - default: - LOG(FATAL) << "Unimplemented: Literal::Equal for type " - << PrimitiveType_Name(shape().element_type()); + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; } } + return true; } -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_preds(); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->data()), values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_u8s(); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->data()), values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_u8s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_s16s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_u16s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_s32s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_u32s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - static_assert(sizeof(int64) == sizeof(tensorflow::protobuf_int64) && - alignof(int64) == alignof(tensorflow::protobuf_int64), - "The int64 and tensorflow::protobuf_int64 types are not " - "compatible"); - auto values = mutable_s64s(); - // Because of the fact that tensorflow::protobuf_int64 is defined as int64_t - // while tensorflow::int64 is defined as long long, a reinterpret_cast<> is - // necessary from the raw data pointer returned by the mutable_data() API. - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->data()), values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - static_assert(sizeof(uint64) == sizeof(tensorflow::protobuf_uint64) && - alignof(uint64) == alignof(tensorflow::protobuf_uint64), - "The uint64 and tensorflow::protobuf_uint64 types are not " - "compatible"); - auto values = mutable_u64s(); - // Because of the fact that tensorflow::protobuf_uint64 is defined as uint64_t - // while tensorflow::uint64 is defined as unsigned long long, a - // reinterpret_cast<> is necessary from the raw data pointer returned by the - // mutable_data() API. - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->data()), values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_f32s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_f64s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_c64s(); - return {values->data(), values->size()}; -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_f16s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice -Literal::GetMutableArraySlice() { - auto values = mutable_bf16s(); - return {values->data(), values->size()}; -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), PRED); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(preds().data()), preds().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), U8); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(u8s().data()), u8s().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), S8); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(u8s().data()), u8s().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), U16); - return tensorflow::gtl::ArraySlice(u16s().data(), u16s().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), S16); - return tensorflow::gtl::ArraySlice(s16s().data(), s16s().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), U32); - return u32s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), U64); - return u64s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), S32); - return s32s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), S64); - return s64s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), F64); - return f64s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), F16); - return tensorflow::gtl::ArraySlice(f16s().data(), - f16s().size() / sizeof(half)); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), BF16); - return {bf16s().data(), bf16s().size()}; -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() - const { - CHECK_EQ(shape().element_type(), C64); - return c64s(); -} +namespace { template -static bool AllElementsEqualValue(const Literal& literal, NativeT value) { - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { - auto multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - if (literal.Get(multi_index) != value) { +static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, + NativeT value) { + for (int64 i = 0; i < data.size(); ++i) { + if (data[i] != value) { return false; } } return true; } +} // namespace + bool Literal::IsAll(int8 value) const { - switch (shape().element_type()) { - case U8: - if (value >= 0) { - return AllElementsEqualValue(*this, value); - } - return false; - case U32: - if (value >= 0) { - return AllElementsEqualValue(*this, value); - } - return false; - case U64: - if (value >= 0) { - return AllElementsEqualValue(*this, value); - } - return false; - case S8: - return AllElementsEqualValue(*this, value); - case S32: - return AllElementsEqualValue(*this, value); - case S64: - return AllElementsEqualValue(*this, value); - case F32: - return AllElementsEqualValue(*this, value); - case F64: - return AllElementsEqualValue(*this, value); - case F16: - return AllElementsEqualValue(*this, static_cast(value)); - case BF16: - return AllElementsEqualValue(*this, - static_cast(value)); - case PRED: - if (value == 0) { - return AllElementsEqualValue(*this, false); - } - if (value == 1) { - return AllElementsEqualValue(*this, true); + for (const auto& pair : pieces_) { + const Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case U8: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U32: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U64: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case S8: + return AllElementsEqualValue(piece.data(), value); + case S32: + return AllElementsEqualValue(piece.data(), value); + case S64: + return AllElementsEqualValue(piece.data(), value); + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case PRED: + if (value == 0) { + return AllElementsEqualValue(piece.data(), false); + } + if (value == 1) { + return AllElementsEqualValue(piece.data(), true); + } + return false; + default: + return false; } return false; - default: + }; + + if (!piece_is_all()) { return false; + } } + return true; } bool Literal::IsAllFloat(float value) const { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(*this, value); - case F64: - return AllElementsEqualValue(*this, value); - case F16: - return AllElementsEqualValue(*this, static_cast(value)); - case BF16: - return AllElementsEqualValue(*this, - static_cast(value)); - default: + for (const auto& pair : pieces_) { + const Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { return false; + } } + return true; } bool Literal::IsAllComplex(complex64 value) const { switch (shape().element_type()) { case C64: - return AllElementsEqualValue(*this, value); + return AllElementsEqualValue(root_piece().data(), + value); default: return false; } } bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { + CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: return Get(indices) == 0; @@ -1376,247 +1586,294 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { } } -template <> -/* static */ void Literal::Resize(int64 num_elements, bool value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_preds()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, int8 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_u8s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, uint8 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_u8s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, int32 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_s32s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, uint32 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_u32s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, int64 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_s64s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, uint64 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_u64s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, float value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_f32s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, double value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_f64s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, half value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_f16s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, bfloat16 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_bf16s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, complex64 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_c64s()->resize(num_elements, value); -} +namespace { template void CopyToRepeatedField(RepeatedFieldT* dest, - const std::vector& src) { + const tensorflow::gtl::ArraySlice src) { *dest = RepeatedFieldT(src.begin(), src.end()); } -template <> -void CopyToRepeatedField, complex64>( - tensorflow::protobuf::RepeatedField* dest, - const std::vector& src) { - *dest = tensorflow::protobuf::RepeatedField( - reinterpret_cast(src.data()), - reinterpret_cast(src.data()) + src.size() * 2); -} +} // namespace -LiteralProto Literal::ToProto() const { - LiteralProto proto; - proto.Clear(); - *proto.mutable_shape() = shape(); - switch (shape().element_type()) { +void Literal::Piece::WriteToProto(LiteralProto* proto) const { + *proto->mutable_shape() = subshape(); + switch (subshape().element_type()) { case PRED: - CopyToRepeatedField(proto.mutable_preds(), preds()); + CopyToRepeatedField(proto->mutable_preds(), data()); break; case U8: - *proto.mutable_u8s() = u8s_string(); - break; - case S32: - CopyToRepeatedField(proto.mutable_s32s(), s32s()); - break; - case S64: - CopyToRepeatedField(proto.mutable_s64s(), s64s()); + proto->set_u8s(static_cast(data().data()), + element_count()); break; case U32: - CopyToRepeatedField(proto.mutable_u32s(), u32s()); + CopyToRepeatedField(proto->mutable_u32s(), data()); break; case U64: - CopyToRepeatedField(proto.mutable_u64s(), u64s()); + CopyToRepeatedField(proto->mutable_u64s(), data()); + break; + case S32: + CopyToRepeatedField(proto->mutable_s32s(), data()); + break; + case S64: + CopyToRepeatedField(proto->mutable_s64s(), data()); break; case F16: - *proto.mutable_f16s() = - string(reinterpret_cast(f16s_.data()), - f16s_.size() * sizeof(half)); + *proto->mutable_f16s() = string( + reinterpret_cast(data().data()), size_bytes()); if (!kLittleEndian) { - ConvertEndianShort(const_cast(proto.mutable_f16s()->data()), - proto.f16s().size()); + ConvertEndianShort(const_cast(proto->mutable_f16s()->data()), + proto->f16s().size()); } break; case BF16: - *proto.mutable_bf16s() = - string(reinterpret_cast(bf16s_.data()), - bf16s_.size() * sizeof(bfloat16)); + *proto->mutable_bf16s() = string( + reinterpret_cast(data().data()), size_bytes()); if (!kLittleEndian) { - ConvertEndianShort(const_cast(proto.mutable_bf16s()->data()), - proto.bf16s().size()); + ConvertEndianShort(const_cast(proto->mutable_bf16s()->data()), + proto->bf16s().size()); } break; case F32: - CopyToRepeatedField(proto.mutable_f32s(), f32s()); + CopyToRepeatedField(proto->mutable_f32s(), data()); break; case F64: - CopyToRepeatedField(proto.mutable_f64s(), f64s()); + CopyToRepeatedField(proto->mutable_f64s(), data()); break; case C64: - CopyToRepeatedField(proto.mutable_c64s(), c64s()); - break; - case TUPLE: - for (const auto& tuple : tuple_literals()) { - *proto.add_tuple_literals() = tuple.ToProto(); + for (complex64 value : data()) { + proto->add_c64s(value.real()); + proto->add_c64s(value.imag()); } break; + case TUPLE: + // Nothing to do but assign the shape which is done above. + return; default: - LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); } - - return proto; } -template -void CopyFromRepeatedField(std::vector* dest, - const RepeatedFieldT& src) { - *dest = std::vector(src.begin(), src.end()); +const void* Literal::Piece::untyped_data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); } -template <> -void CopyFromRepeatedField, - complex64>( - std::vector* dest, - const tensorflow::protobuf::RepeatedField& src) { - *dest = std::vector( - reinterpret_cast(src.data()), - reinterpret_cast(src.data()) + src.size() / 2); +void* Literal::Piece::untyped_data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); } -void Literal::CopyFromProto(const LiteralProto& literal_proto) { - if (!literal_proto.has_shape()) { - return; +namespace { + +template +Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, + const RepeatedFieldT& src) { + if (dest.size() != src.size()) { + return InvalidArgument( + "Expected %lu elements in LiteralProto repeated field, has %d", + dest.size(), src.size()); } + std::copy(src.begin(), src.end(), dest.begin()); + return Status::OK(); +} - *mutable_shape() = literal_proto.shape(); - switch (shape().element_type()) { +} // namespace + +Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { + // These conditions should have been checked in Literal::CreateFromProto. + TF_RET_CHECK(proto.has_shape()); + TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); + TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + + switch (subshape().element_type()) { case PRED: - CopyFromRepeatedField(mutable_preds(), literal_proto.preds()); - break; - case U8: - set_u8s(literal_proto.u8s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); break; + case U8: { + auto u8_data = data(); + TF_RET_CHECK(proto.u8s().size() == u8_data.size()); + std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin()); + } break; case S32: - CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s32s())); break; case S64: - CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s64s())); break; case U32: - CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u32s())); break; case U64: - CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; case F16: { - const string& s(literal_proto.f16s()); - CHECK_EQ(0, s.size() % sizeof(half)); - f16s_ = std::vector(s.size() / sizeof(half)); - memcpy(f16s_.data(), s.data(), s.size()); - + const string& s(proto.f16s()); + TF_RET_CHECK(data().size() * sizeof(half) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(f16s_.data()), s.size()); + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); } - break; - } - case BF16: { - const string& s(literal_proto.bf16s()); - CHECK_EQ(0, s.size() % sizeof(bfloat16)); - bf16s_ = std::vector(s.size() / sizeof(bfloat16)); - memcpy(bf16s_.data(), s.data(), s.size()); + } break; + case BF16: { + const string& s(proto.bf16s()); + TF_RET_CHECK(data().size() * sizeof(bfloat16) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(bf16s_.data()), s.size()); + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); } - break; - } + } break; case F32: - CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f32s())); break; case F64: - CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f64s())); break; - case C64: - CopyFromRepeatedField(mutable_c64s(), literal_proto.c64s()); - break; - case TUPLE: - for (const auto& proto : literal_proto.tuple_literals()) { - mutable_tuple_literals()->push_back(Literal(proto)); + case C64: { + auto complex_data = data(); + TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2); + for (int64 i = 0; i < complex_data.size(); ++i) { + complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; } + } break; + case TUPLE: + LOG(FATAL) << "Should not be called on tuple shapes: " + << ShapeUtil::HumanString(subshape()); break; default: - LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); } + return Status::OK(); } -const Literal& Literal::GetSubliteral(const ShapeIndex& index) const { - return const_cast(this)->GetSubliteral(index); +LiteralProto Literal::ToProto() const { + LiteralProto proto; + for (const auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + const Piece& piece = pair.second; + + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + } + + if (LayoutUtil::IsSparseArray(shape())) { + CopyToRepeatedField(proto.mutable_sparse_indices(), + sparse_indices()->data()); + } + + return proto; +} + +/* static */ +StatusOr> Literal::CreateFromProto( + const LiteralProto& proto) { + if (!proto.has_shape()) { + return InvalidArgument("LiteralProto has no shape"); + } + if (!LayoutUtil::HasLayout(proto.shape())) { + return InvalidArgument("LiteralProto has no layout"); + } + + auto literal = MakeUnique(proto.shape()); + + for (auto& pair : literal->pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + TF_RET_CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } + + if (ShapeUtil::IsTuple(piece.subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece.subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece.subshape()), + proto_element->tuple_literals_size()); + } + continue; + } + + TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape())); + TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element)); + } + return std::move(literal); +} + +const void* Literal::untyped_data(const ShapeIndex& shape_index) const { + return piece(shape_index).untyped_data(); +} + +void* Literal::untyped_data(const ShapeIndex& shape_index) { + return piece(shape_index).untyped_data(); +} + +int64 Literal::size_bytes(const ShapeIndex& shape_index) const { + return piece(shape_index).size_bytes(); +} + +string Literal::GetR1U8AsString() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(shape().element_type(), U8); + return string(tensorflow::bit_cast(data().data()), + ShapeUtil::ElementsIn(shape())); +} + +/* static */ const LiteralView LiteralView::Create( + const Literal& literal, const ShapeIndex& view_root) { + return LiteralView(literal, view_root); +} + +LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { + shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root); + pieces_ = ShapeTree(shape_); + owns_buffers_ = false; + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + + ShapeIndex src_index = view_root; + for (int64 i : index) { + src_index.push_back(i); + } + const Piece& src_piece = literal.piece(src_index); + piece.set_buffer(src_piece.buffer()); + piece.set_sparse_indices(src_piece.sparse_indices()); + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + } +} + +LiteralView::~LiteralView() {} + +LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); } + +LiteralView& LiteralView::operator=(const LiteralView& other) { + CopyFrom(other); + return *this; } -Literal& Literal::GetSubliteral(const ShapeIndex& index) { - Literal* subliteral = this; - for (int64 i : index) { - subliteral = &subliteral->tuple_literals_.at(i); +void LiteralView::CopyFrom(const LiteralView& other) { + // We can't use the default copy-constructor/copy-assignment because + // Piece::subshape_ points to subshapes within the Shape of the owning + // Literal/LiteralView. + shape_ = other.shape(); + pieces_ = other.pieces_; + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); } - return *subliteral; + owns_buffers_ = false; } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index f37e529caf54e3aded1a418d1f01c1440cd0f284..e0196509a7483abac3d9c0e59a54b591a327b980 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -34,7 +34,9 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -50,152 +52,70 @@ limitations under the License. namespace xla { -// Utility class for dealing with XLA literal values. Most methods are -// templated by native (host) type which corresponds to a unique XLA -// PrimitiveType. See ComputationBuilder for details. Not all primitive types -// defined in xla_data.proto have a corresponding native type or even have a -// storage location in the Literal proto yet (for example, primitive type F16). +// Class representing literal values in XLA. +// +// TODO(b/67651157): The methods in this class should be reduced to a minimal +// set of methods which construct Literals and accessors methods. Other methods +// which perform computation on Literals (Reshape, Slice, etc) should be moved +// elsewhere, and perhaps combined with evaluator code which operates on +// Literals. class Literal { public: - Literal() {} + Literal() : Literal(ShapeUtil::MakeNil()) {} - Literal(const Literal& other) = default; - Literal(Literal&&) = default; + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); - explicit Literal(const LiteralProto& other) { CopyFromProto(other); } - - Literal& operator=(const Literal& other) = default; - Literal& operator=(Literal&&) = default; + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + Literal& operator=(Literal&& other); // Literals are equal if they have compatible shapes and the same data - // values. Layout is not checked. + // values. Layout is not compared. bool operator==(const Literal& other) const; bool operator!=(const Literal& other) const { return !(*this == other); } + // Serialize to and from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); LiteralProto ToProto() const; - bool has_shape() const { - return shape_.element_type() != PRIMITIVE_TYPE_INVALID; - } - - // Basic accessor functions. Names mirror the original protobuf - // functions for convenience. - string DebugString() const { return ToProto().DebugString(); } - string ShortDebugString() const { return ToProto().ShortDebugString(); } - - // Return the nested literal at the given shape index. - const Literal& GetSubliteral(const ShapeIndex& index) const; - Literal& GetSubliteral(const ShapeIndex& index); - - void Clear() { - shape_.Clear(); - u8s_.clear(); - s16s_.clear(); - s32s_.clear(); - s64s_.clear(); - u16s_.clear(); - u32s_.clear(); - u64s_.clear(); - f16s_.clear(); - f32s_.clear(); - f64s_.clear(); - tuple_literals_.clear(); - } - - int preds_size() const { return u8s().size(); } - const std::vector& preds() const { - static_assert(sizeof(uint8) == sizeof(bool), - "The uint8 and bool types should be the same size"); - return u8s_; - } - std::vector* mutable_preds() { - static_assert(sizeof(uint8) == sizeof(bool), - "The uint8 and bool types should be the same size"); - return &u8s_; - } - - int s16s_size() const { return s16s().size(); } - int32 s16s(int i) const { return s16s_[i]; } - const std::vector& s16s() const { return s16s_; } - std::vector* mutable_s16s() { return &s16s_; } - - int s32s_size() const { return s32s().size(); } - int32 s32s(int i) const { return s32s_[i]; } - const std::vector& s32s() const { return s32s_; } - std::vector* mutable_s32s() { return &s32s_; } - - int s64s_size() const { return s64s().size(); } - void add_s64s(int64 value) { s64s_.push_back(value); } - const std::vector& s64s() const { return s64s_; } - std::vector* mutable_s64s() { return &s64s_; } - - int u16s_size() const { return u16s().size(); } - uint32 u16s(int i) const { return u16s_[i]; } - const std::vector& u16s() const { return u16s_; } - std::vector* mutable_u16s() { return &u16s_; } - - int u32s_size() const { return u32s().size(); } - uint32 u32s(int i) const { return u32s_[i]; } - const std::vector& u32s() const { return u32s_; } - std::vector* mutable_u32s() { return &u32s_; } - - int u64s_size() const { return u64s().size(); } - const std::vector& u64s() const { return u64s_; } - std::vector* mutable_u64s() { return &u64s_; } - - int f16s_size() const { return f16s().size(); } - half f16s(int i) const { return f16s_[i]; } - const std::vector& f16s() const { return f16s_; } - std::vector* mutable_f16s() { return &f16s_; } - - int f32s_size() const { return f32s().size(); } - float f32s(int i) const { return f32s_[i]; } - void add_f32s(float value) { f32s_.push_back(value); } - const std::vector& f32s() const { return f32s_; } - std::vector& f32s() { return f32s_; } - std::vector* mutable_f32s() { return &f32s_; } - - int f64s_size() const { return f64s().size(); } - const std::vector& f64s() const { return f64s_; } - std::vector* mutable_f64s() { return &f64s_; } - - int c64s_size() const { return c64s().size(); } - const std::vector& c64s() const { return c64s_; } - std::vector* mutable_c64s() { return &c64s_; } - - int bf16s_size() const { return bf16s().size(); } - bfloat16 bf16s(int i) const { return bf16s_[i]; } - const std::vector& bf16s() const { return bf16s_; } - std::vector* mutable_bf16s() { return &bf16s_; } - - int tuple_literals_size() const { return tuple_literals().size(); } - const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } - Literal* add_tuple_literals() { - tuple_literals_.push_back(Literal()); - return &tuple_literals_.back(); - } - std::vector* mutable_tuple_literals() { return &tuple_literals_; } - const std::vector& tuple_literals() const { return tuple_literals_; } - - int u8s_size() const { return u8s().size(); } - const std::vector& u8s() const { return u8s_; } - void set_u8s(const std::vector& value) { u8s_ = value; } - void set_u8s(tensorflow::StringPiece value) { - u8s_ = std::vector(value.size()); - u8s_.clear(); - append_u8s(value); - } - - void append_u8s(tensorflow::StringPiece value) { - u8s_.insert(u8s_.end(), value.begin(), value.end()); - } - - string u8s_string() const { return string(u8s().begin(), u8s().end()); } + // Return the shape of the literal. + const Shape& shape() const { return shape_; } - std::vector* mutable_u8s() { return &u8s_; } + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return &shape_; } - const Shape& shape() const { return shape_; } - Shape* mutable_shape() { return &shape_; } + // Returns a (Mutable)ArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // Returns a pointer to (or size of) the underlying buffer holding the array + // at the given shape index. CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. + const void* untyped_data(const ShapeIndex& shape_index = {}) const; + void* untyped_data(const ShapeIndex& shape_index = {}); + int64 size_bytes(const ShapeIndex& shape_index = {}) const; // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -243,6 +163,60 @@ class Literal { values, const Layout& layout); + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Creates a literal with a sparse layout and the given indices and values. + // The shape is initialized from the given dimensions. The minor dimension of + // the indices array must equal the rank of the shape (i.e. size of the + // dimensions array). The major dimension of the indices array must equal the + // number of elements in the values array. The maximum number of elements in + // the array is taken from the max_indices() value of the index array. + // + // XLA assumes that sparse literals are in sorted order for all operations. If + // the `sort` argument is true, then the indices and values will be sorted + // while copying them into the literal. If you have ensured that the indices + // and values are already sorted, then you may set the `sort` argument to + // false to skip the sorting step. + // + // For example: + // + // CreateSparse( + // {12, 12, 12}, + // SparseIndexArray(10, 3, + // Array2D{ + // {0, 1, 2}, + // {3, 4, 5}, + // {6, 7, 8}, + // {9, 10, 11}, + // }), + // {1.0, 2.0 3.0, 4.0}) + // + // This creates an array with shape F64[12,12,12]sparse{10}, that has the + // following non-zero values: + // + // [0, 1, 2]: 1.0 + // [3, 4, 5]: 2.0 + // [6, 7, 8]: 3.0 + // [9, 10, 11]: 4.0 + // + template + static std::unique_ptr CreateSparse( + tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, bool sort = true); + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). @@ -256,6 +230,23 @@ class Literal { PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions); + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const Literal& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Similar to CopyFrom, but with move semantincs. The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + // Copies the values from src_literal, starting at src_base shape indexes, // to this literal, starting at dest_base, where the copy size in each // dimension is specified by copy_size. @@ -265,10 +256,24 @@ class Literal { // Note: if either src_literal or this literal contains dimensions with zero // element, then copy_size must be 0 in these dimensions while the // corresponding base indices being 0. - Status Copy(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + // This literal and 'src_literal' must be arrays. + Status CopySliceFrom(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); + + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); // Creates a new value that has the equivalent value as this literal, but // conforms to new_layout; e.g. a literal matrix that was in {0, 1} @@ -285,11 +290,16 @@ class Literal { std::unique_ptr Relayout(const Layout& new_layout, const ShapeIndex& shape_index = {}) const; - // Creates a new literal by reshaping this literal to have 'shape'. Both the - // original shape and 'shape' must contain the same number of elements. The + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. StatusOr> Reshape( - tensorflow::gtl::ArraySlice shape) const; + tensorflow::gtl::ArraySlice dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -297,6 +307,7 @@ class Literal { // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. std::unique_ptr Transpose( tensorflow::gtl::ArraySlice permutation) const; @@ -305,6 +316,7 @@ class Literal { // same rank and layout as for the given literal. The number of indices in // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. + // This literal must be an array. std::unique_ptr Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const; @@ -312,34 +324,35 @@ class Literal { // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. + // This literal must be an array. template std::unique_ptr Replicate(int64 times) const; // Converts this literal to another primitive type. Returns an error if the - // conversion is not possible. + // conversion is not possible. This literal must be array-shaped. StatusOr> Convert( PrimitiveType primitive_dest_type) const; - // Creates a literal value zero of the given primitive type. + // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Creates a literal value one of the given primitive type. + // Creates a scalar literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); - // Creates a literal value containing the minimum value of the given + // Creates a scalar literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Creates a literal value containing the maximum value of the given + // Creates a scalar literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); // Creates a literal of the given shape where each element is `value`. template - static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( + static std::unique_ptr CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value); - // Creates a new literal from an array. The variants not ending with + // Creates a new literal from an Array type. The variants not ending with // WithLayout use the default XLA layout for the literal's linear // representation in memory. template @@ -388,35 +401,50 @@ class Literal { std::initializer_list> values, int64 projection_p, int64 projection_z); - // Clones this literal into an owned unique_ptr version. + // Clones this literal into a new Literal, or new std::unique_ptr. + Literal Clone() const; std::unique_ptr CloneToUnique() const; - // Returns the linear index of the given index within this literal's - // element_type repeated field. - int64 LinearIndex(tensorflow::gtl::ArraySlice multi_index) const; + // Gets or sets an element in the literal at the given index. The multi_index + // is CHECKed against the dimension sizes. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); - // Gets or sets an element in the literal at the given index. The index is - // CHECKed against the dimension sizes. + // Overloads of Get and Set for array literals. CHECKs if the literal is not + // array-shaped and dense. template NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; template void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - // Returns a (Mutable)ArraySlice view of the array for this literal for the - // given NativeT (e.g., float). These functions map native type to XLA - // PrimitiveType via template specialization. The unspecialized forms below - // aborts to handle the error case where the given native type does not map to - // an XLA primitive type. + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. template - tensorflow::gtl::ArraySlice GetArraySlice() const { - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. template - tensorflow::gtl::MutableArraySlice GetMutableArraySlice() { - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -425,10 +453,16 @@ class Literal { // As Get(), but determines the correct type and converts the value // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index) const; + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; // As Get(), but determines the correct type and converts the value into - // int64. + // int64. This literal must be an array. StatusOr GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const; @@ -436,7 +470,8 @@ class Literal { template static std::unique_ptr MakeIdentityR2(int64 size); - // Returns a tuple literal composed of given literals. + // Returns a tuple literal composed of given literals. Data is copied from the + // given elements into the returned literal. static std::unique_ptr MakeTuple( tensorflow::gtl::ArraySlice elements); @@ -450,10 +485,6 @@ class Literal { static std::unique_ptr MakeTupleOwned( std::vector> elements); - // Validates that the data payload of the literal matches the literal shape; - // if it does not, an appropriate status is returned. - tensorflow::Status ValidateLiteral() const; - // Returns a string representation of the literal value. string ToString(bool print_layout = false) const; @@ -464,6 +495,8 @@ class Literal { // This function is useful if you want a polymorphic representation // of the tensor's elements (turning it to a string for something // like representation in a protobuf). + // + // This literal must have a dense layout. void EachCellAsString( const std::function indices, const string& value)>& per_cell) const; @@ -472,80 +505,45 @@ class Literal { NativeT value)> per_cell) const; - // Templated methods which populate the given repeated field in this literal - // with the given value(s). The Shape field of this literal is set - // to match the array dimensions and type. Examples: + // Populate this literal with the given values. Examples: // // // Populate with floats. // Array2D float_values = ... // literal.PopulateR2FromArray2D(values); // // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); + // literal.PopulateR2({{1, 2}, {3, 4}}); // - template - void PopulateR0(NativeT values); + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. template void PopulateR1(tensorflow::gtl::ArraySlice values); void PopulateR1(const tensorflow::core::Bitmap& values); template void PopulateR2(std::initializer_list> values); template - void PopulateR2WithLayout( - std::initializer_list> values, - const Layout& layout); - template void PopulateFromArray(const Array& values); template - void PopulateFromArrayWithLayout(const Array& values, - const Layout& layout); - template void PopulateR2FromArray2D(const Array2D& values); template - void PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout); - template void PopulateR3FromArray3D(const Array3D& values); template - void PopulateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout); - template void PopulateR4FromArray4D(const Array4D& values); - template - void PopulateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout); // Populates literal values by calling the generator function for every cell // in this literal object. // // generator must be a callable of the type // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. template Status Populate(const FnType& generator); - // Creates a Literal of the given dimensions with all elements set to the - // given value. - template - void PopulateWithValue(NativeT value, - tensorflow::gtl::ArraySlice dimensions); - - // Returns a pointer to the underlying vector corresponding to the Literal's - // shape. - const void* InternalData() const; - void* MutableInternalData(); - - // Allocates space in the underlying vector of this literal sufficient to hold - // num_elements of this literal's primitive type. Values in the vector are set - // to zero. num_elements must equal the number of elements in the literal's - // shape. - void Reserve(int64 num_elements); - - // Allocates space in the underlying vector of this literal sufficient to hold - // num_elements of this literal's primitive type and sets each element in this - // literal to the given value. num_elements must equal the number of elements - // in this literal's shape. + // Fills this literal with the given value. template - void Resize(int64 num_elements, NativeT value); + void PopulateWithValue(NativeT value); // Returns whether every element in this literal is equal to value. // @@ -555,7 +553,7 @@ class Literal { // // If value doesn't fit in this literal's type, returns false. Values of 1/0 // are considered equal to true/false; other values are not considered equal - // to true. + // to true. Also if this literal is not array-shaped false is returned. bool IsAll(int8 value) const; // Like IsAll(const Literal&, int8), except we check whether the literal is @@ -566,7 +564,7 @@ class Literal { // This casts value to the type of literal, then compares using ==. The usual // admonishments about floating-point equality checks apply. We expect you to // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. + // e.g. -0.5. Also if this literal is not array-shaped false is returned. bool IsAllFloat(float value) const; // Like IsAll(const Literal&, int8), except we check whether the literal is @@ -578,23 +576,38 @@ class Literal { // admonishments about floating-point equality checks apply. We expect you to // use this to check for complex values that can be expressed precisely as // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. bool IsAllComplex(complex64 value) const; // Returns whether this literal is zero at the specified index. This literal - // must be an array. + // must be an array with a dense layout. bool IsZero(tensorflow::gtl::ArraySlice indices) const; - private: - // Copy from a LiteralProto instance. - void CopyFromProto(const LiteralProto& literal_proto); + // Return the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Return the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + protected: + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); - // Internal template helper for the Copy() API, matching its arguments one by - // one. - template - Status CopyRange(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + // Internal template helper for the Literal::CopySliceFrom(), matching its + // arguments one by one. + template + Status CopySliceFromInternal(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); // Utility structure which is used to create the optimal configuration for // a ShapeUtil::ForEachIndex() scan across two literals. @@ -619,163 +632,243 @@ class Literal { int64 minor_loop_size = 1; }; - Shape shape_; - std::vector u8s_; - std::vector s16s_; - std::vector s32s_; - std::vector s64s_; - std::vector u16s_; - std::vector u32s_; - std::vector u64s_; - std::vector bf16s_; - std::vector f16s_; - std::vector f32s_; - std::vector f64s_; - std::vector c64s_; - std::vector tuple_literals_; -}; - -std::ostream& operator<<(std::ostream& out, const Literal& literal); - -// Declarations of template specializations for GetArraySlice and -// GetMutableArraySlice. The specializations map native type to XLA primitive -// type. -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Return the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Return the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Returns the number of elements in this piece's array. + int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); } -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; -template <> -inline tensorflow::gtl::ArraySlice Literal::GetArraySlice() - const { - DCHECK(shape().element_type() == F32); - return f32s(); -} + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Sorts the elements in a sparse array. + void SortSparseElements(); -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + private: + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() - const; + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + }; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return *pieces_.mutable_element(shape_index); + } + const Piece& piece(const ShapeIndex& shape_index) const { + return pieces_.element(shape_index); + } -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // Returns the piece at the root of the shape (empty ShapeIndex). + Piece& root_piece() { return piece({}); } + const Piece& root_piece() const { return piece({}); } -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // Deallocate the buffers held by this literal (if the literal owns the + // buffer). + void DeallocateBuffers(); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + Shape shape_; + ShapeTree pieces_; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // Whether the buffers held in pieces_ are owned by this Literal. + bool owns_buffers_; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // LiteralView must access and manipulate Pieces of other Literals. + friend class LiteralView; +}; // namespace xla -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +std::ostream& operator<<(std::ostream& out, const Literal& literal); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +// A read-only view of a Literal. A LiteralView contains pointers to buffers +// owned by the viewed Literal. +// +// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable +// and mutable) similar to (Mutable)ArraySlice. +class LiteralView : public Literal { + public: + // Create and return a view of the given literal rooted at the given shape + // index within the given literal. A factory is used rather than a public + // constructor because only const LiteralViews are supported. It's still + // possible to create non-const LiteralViews via the copy constructors, but + // the factory method makes it a bit less likely. Implementing literal slices + // will fix this undesirable situation (b/71550060). + static const LiteralView Create(const Literal& literal, + const ShapeIndex& view_root = {}); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + LiteralView(const LiteralView& other); + LiteralView& operator=(const LiteralView& other); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + virtual ~LiteralView(); -template <> -void Literal::Resize(int64 num_elements, bool value); + private: + LiteralView(const Literal& literal, const ShapeIndex& view_root); -template <> -void Literal::Resize(int64 num_elements, int8 value); + // Helper for the copy constructor and copy assignment operator. + void CopyFrom(const LiteralView& other); +}; -template <> -void Literal::Resize(int64 num_elements, uint8 value); +template +tensorflow::gtl::ArraySlice Literal::Piece::data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(buffer()), + ShapeUtil::ElementsIn(subshape())); +} -template <> -void Literal::Resize(int64 num_elements, int32 value); +template +tensorflow::gtl::MutableArraySlice Literal::Piece::data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(buffer()), ShapeUtil::ElementsIn(subshape())); +} -template <> -void Literal::Resize(int64 num_elements, uint32 value); +template +NativeT Literal::Piece::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(subshape())); + return data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)]; +} -template <> -void Literal::Resize(int64 num_elements, int64 value); +template +void Literal::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + CHECK(LayoutUtil::IsDenseArray(subshape())); + data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)] = value; +} -template <> -void Literal::Resize(int64 num_elements, uint64 value); +template +tensorflow::gtl::ArraySlice Literal::data( + const ShapeIndex& shape_index) const { + return piece(shape_index).data(); +} -template <> -void Literal::Resize(int64 num_elements, float value); +template +tensorflow::gtl::MutableArraySlice Literal::data( + const ShapeIndex& shape_index) { + return piece(shape_index).data(); +} -template <> -void Literal::Resize(int64 num_elements, double value); +template +inline NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + return piece(shape_index).Get(multi_index); +} -template <> -void Literal::Resize(int64 num_elements, half value); +template +inline NativeT Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + return root_piece().Get(multi_index); +} -template <> -void Literal::Resize(int64 num_elements, bfloat16 value); +template +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value) { + return piece(shape_index).Set(multi_index, value); +} -template <> -void Literal::Resize(int64 num_elements, complex64 value); +template +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + return root_piece().Set(multi_index, value); +} template /* static */ std::unique_ptr Literal::CreateR0(NativeT value) { - auto literal = MakeUnique(); - literal->PopulateR0(value); + auto literal = MakeUnique(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {})); + literal->Set({}, value); return literal; } template /* static */ std::unique_ptr Literal::CreateR1( tensorflow::gtl::ArraySlice values) { - auto literal = MakeUnique(); + auto literal = MakeUnique( + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + {static_cast(values.size())})); literal->PopulateR1(values); return literal; } @@ -784,8 +877,12 @@ template /* static */ std::unique_ptr Literal::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR2WithLayout(values, layout); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), + {static_cast(values.size()), + static_cast(values.begin()->size())}, + AsInt64Slice(layout.minor_to_major()))); + literal->PopulateR2(values); return literal; } @@ -858,6 +955,21 @@ template return CreateR4FromArray4DWithLayout(tmp, layout); } +template +/* static */ std::unique_ptr Literal::CreateSparse( + tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, bool sort) { + int64 num_elements = values.size(); + int64 rank = dimensions.size(); + CHECK_EQ(num_elements, indices.index_count()); + CHECK_EQ(rank, indices.rank()); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); + literal->PopulateSparse(indices, values, sort); + return literal; +} + template /* static */ std::unique_ptr Literal::CreateR4( std::initializer_list template /* static */ std::unique_ptr Literal::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateFromArrayWithLayout(values, layout); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), values.dimensions(), + AsInt64Slice(layout.minor_to_major()))); + literal->PopulateFromArray(values); return literal; } @@ -970,81 +1084,33 @@ template return CreateFromArrayWithLayout(values, layout); } -template -NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index) const { - int64 linear_index = LinearIndex(multi_index); - return GetArraySlice().at(linear_index); -} - template NativeT Literal::GetFirstElement() const { - return GetArraySlice().at(0); -} - -template <> -inline uint8 Literal::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(shape().element_type() == U8); - int64 linear_index = LinearIndex(multi_index); - return u8s()[linear_index]; -} - -template <> -inline int8 Literal::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(shape().element_type() == S8); - int64 linear_index = LinearIndex(multi_index); - return u8s()[linear_index]; -} - -template <> -inline half Literal::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(shape().element_type() == F16); - int64 linear_index = LinearIndex(multi_index); - return GetArraySlice()[linear_index]; -} - -template <> -inline bfloat16 Literal::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(shape().element_type() == BF16); - int64 linear_index = LinearIndex(multi_index); - return GetArraySlice()[linear_index]; + return data().at(0); } template -void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - int64 linear_index = LinearIndex(multi_index); - GetMutableArraySlice().at(linear_index) = value; -} - -template <> -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - uint8 value) { - int64 linear_index = LinearIndex(multi_index); - (*mutable_u8s())[linear_index] = value; +NativeT Literal::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { + CHECK( + LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); + return data(shape_index)[sparse_element_number]; } -template <> -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - int8 value) { - return Set(multi_index, value); -} - -template <> -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - int64 value) { - int64 linear_index = LinearIndex(multi_index); - (*mutable_s64s())[linear_index] = value; -} - -template <> -/* static */ inline void Literal::Set( - tensorflow::gtl::ArraySlice multi_index, uint64 value) { - int64 linear_index = LinearIndex(multi_index); - (*mutable_u64s())[linear_index] = value; +template +void Literal::AppendSparseElement( + tensorflow::gtl::ArraySlice multi_index, NativeT value, + const ShapeIndex& shape_index) { + Piece& p = piece(shape_index); + const Shape& subshape = p.subshape(); + CHECK(LayoutUtil::IsSparseArray(subshape)); + int64 rank = ShapeUtil::Rank(subshape); + CHECK_EQ(multi_index.size(), rank); + int64 last_element = p.sparse_indices()->index_count(); + CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); + p.sparse_indices()->Append(multi_index); + CHECK_LT(last_element, p.data().size()); + p.data()[last_element] = value; } // Returns an identity matrix (rank 2) with the given row and column count. @@ -1071,51 +1137,31 @@ void Literal::EachCell( } while (IndexUtil::BumpIndices(shape(), &indices)); } -template -inline void Literal::PopulateR0(NativeT value) { - *mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {}); - Resize(1, value); -} - template inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { - *mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), - {static_cast(values.size())}); - Reserve(values.size()); + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); for (int64 i = 0; i < values.size(); ++i) { Set({i}, values[i]); } } -inline void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { - *mutable_shape() = - ShapeUtil::MakeShape(PRED, {static_cast(values.bits())}); - Reserve(values.bits()); - for (int64 i = 0; i < static_cast(values.bits()); ++i) { - Set({i}, values.get(i)); - } -} - template -void Literal::PopulateR2WithLayout( - std::initializer_list> values, - const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {static_cast(values.size()), - static_cast(values.begin()->size())}, - AsInt64Slice(layout.minor_to_major())); +void Literal::PopulateR2( + std::initializer_list> values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 2); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); const int64 dim0_size = values.size(); const int64 dim1_size = values.begin()->size(); CHECK_EQ(dim0_size, shape().dimensions(0)); CHECK_EQ(dim1_size, shape().dimensions(1)); - const int64 num_elements = dim1_size * dim0_size; - Reserve(num_elements); - int64 dim0 = 0; for (auto inner_list : values) { int64 dim1 = 0; @@ -1129,69 +1175,65 @@ void Literal::PopulateR2WithLayout( } template -void Literal::PopulateR2( - std::initializer_list> values) { - PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); -} - -template -void Literal::PopulateFromArrayWithLayout(const Array& values, - const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), values.dimensions(), - AsInt64Slice(layout.minor_to_major())); - Reserve(values.num_elements()); +void Literal::PopulateFromArray(const Array& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); + for (int dim = 0; dim < values.num_dimensions(); ++dim) { + CHECK_EQ(values.dim(dim), shape().dimensions(dim)); + } values.Each([this](tensorflow::gtl::ArraySlice indices, NativeT value) { this->Set(indices, value); }); } -template -void Literal::PopulateFromArray(const Array& values) { - PopulateFromArrayWithLayout( - values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); -} - -template -void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { - PopulateFromArrayWithLayout(values, layout); -} - template void Literal::PopulateR2FromArray2D(const Array2D& values) { PopulateFromArray(values); } -template -void Literal::PopulateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { - PopulateFromArrayWithLayout(values, layout); -} - template void Literal::PopulateR3FromArray3D(const Array3D& values) { PopulateFromArray(values); } template -void Literal::PopulateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { - PopulateFromArrayWithLayout(values, layout); +void Literal::PopulateR4FromArray4D(const Array4D& values) { + PopulateFromArray(values); } template -void Literal::PopulateR4FromArray4D(const Array4D& values) { - PopulateFromArray(values); +void Literal::PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort) { + CHECK(LayoutUtil::IsSparseArray(shape())); + int rank = ShapeUtil::Rank(shape()); + CHECK_EQ(indices.rank(), rank); + int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); + CHECK_LE(indices.max_indices(), max_elements); + int64 num_elements = values.size(); + CHECK_LE(num_elements, max_elements); + CHECK_EQ(num_elements, indices.index_count()); + auto root_data = root_piece().data(); + root_data.remove_suffix(max_elements - values.size()); + std::copy(values.begin(), values.end(), root_data.begin()); + *this->root_piece().sparse_indices() = std::move(indices); + if (sort) { + auto root_data = this->root_piece().data(); + root_data.remove_suffix(root_data.size() - num_elements); + this->root_piece().sparse_indices()->SortWithValues(root_data); + } + DCHECK(this->root_piece().sparse_indices()->Validate(shape())); } template Status Literal::Populate(const FnType& generator) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); + TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); - tensorflow::gtl::MutableArraySlice data = - GetMutableArraySlice(); + tensorflow::gtl::MutableArraySlice literal_data = data(); if (rank > 0) { StrideConfig stride_config(this_shape, this_shape, AsInt64Slice(this_shape.dimensions())); @@ -1200,11 +1242,12 @@ Status Literal::Populate(const FnType& generator) { ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); auto init_function = [&](const std::vector& indexes) { - const int64 index = LinearIndex(indexes); + const int64 index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); for (int64 i = 0; i < minor_dimension_size; ++i) { minor_scan_indexes[stride_config.minor_dimension] = i; - data.at(index + i) = generator(minor_scan_indexes); + literal_data.at(index + i) = generator(minor_scan_indexes); } return true; }; @@ -1213,32 +1256,27 @@ Status Literal::Populate(const FnType& generator) { init_function); } else { // For scalars. - data.at(0) = generator({}); + literal_data.at(0) = generator({}); } return Status::OK(); } template -void Literal::PopulateWithValue(NativeT value, - tensorflow::gtl::ArraySlice dimensions) { - *mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), dimensions); - Resize(ShapeUtil::ElementsIn(shape()), value); +void Literal::PopulateWithValue(NativeT value) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + for (NativeT& element : data()) { + element = value; + } } template -/* static */ std::unique_ptr -Literal::CreateFullWithMonotonicDim0MajorLayout( +/* static */ std::unique_ptr Literal::CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { - Shape this_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - primitive_util::NativeToPrimitiveType(), dimensions); - auto literal = MakeUnique(); - *literal->mutable_shape() = this_shape; - literal->Reserve(ShapeUtil::ElementsIn(this_shape)); - std::vector index(dimensions.size(), 0); - do { - literal->Set(index, value); - } while (IndexUtil::BumpIndices(this_shape, &index)); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType(), dimensions)); + literal->PopulateWithValue(value); return literal; } @@ -1249,14 +1287,12 @@ std::unique_ptr Literal::Replicate(int64 times) const { for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = MakeUnique(); - *literal->mutable_shape() = - ShapeUtil::MakeShape(shape().element_type(), bounds); + auto literal = + MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); int64 elements = ShapeUtil::ElementsIn(literal->shape()); if (elements == 0) { return literal; } - literal->Reserve(elements); DimensionVector output_indices(bounds.size(), 0); tensorflow::gtl::ArraySlice input_indices = output_indices; diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 816bb3c549eaae4e8fc2b7d438627266603272f9..b3583c2eb75de8297d5e7507430491f119bd4462 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -31,6 +31,7 @@ namespace xla { namespace { using ::testing::ElementsAre; +using ::testing::HasSubstr; class LiteralUtilTest : public ::testing::Test { protected: @@ -192,6 +193,34 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { ASSERT_EQ(expected, result); } +TEST_F(LiteralUtilTest, CreateSparse) { + std::vector dimensions = {8, 8, 8}; + Array2D indices = { + {3, 4, 5}, + {1, 2, 3}, + {2, 3, 4}, + {3, 5, 6}, + }; + std::vector values = {7, 8, 9, 10}; + auto literal = Literal::CreateSparse( + dimensions, SparseIndexArray(indices.n1() + 3, indices), values); + + Array2D expected_indices = { + {1, 2, 3}, + {2, 3, 4}, + {3, 4, 5}, + {3, 5, 6}, + }; + std::vector expected_values = {8, 9, 7, 10}; + + EXPECT_EQ(literal->sparse_indices()->data(), + tensorflow::gtl::ArraySlice( + expected_indices.data(), expected_indices.num_elements())); + EXPECT_EQ(tensorflow::gtl::ArraySlice(literal->data().data(), + expected_values.size()), + tensorflow::gtl::ArraySlice(expected_values)); +} + TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off auto literal = Literal::CreateR4Projected({ @@ -293,29 +322,28 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { auto matrix_different = Literal::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); auto vector_literal = Literal::CreateR1({1.0, 2.0, 3.0, 4.0}); auto scalar = Literal::CreateR0(1.0); + Literal nil(ShapeUtil::MakeNil()); EXPECT_EQ(*matrix, *matrix); EXPECT_EQ(*matrix, *matrix_clone); EXPECT_NE(*matrix, *matrix_different); EXPECT_NE(*matrix, *vector_literal); EXPECT_NE(*matrix, *scalar); + EXPECT_NE(*matrix, nil); + EXPECT_EQ(nil, nil); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = MakeUnique(); - *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); - *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - colmajor->Reserve(4); + auto colmajor = + MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); colmajor->Set({0, 0}, 1.0); colmajor->Set({0, 1}, 2.0); colmajor->Set({1, 0}, 3.0); colmajor->Set({1, 1}, 4.0); - auto rowmajor = MakeUnique(); - *rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); - *rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - rowmajor->Reserve(4); + auto rowmajor = + MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); rowmajor->Set({0, 0}, 1.0); rowmajor->Set({0, 1}, 2.0); rowmajor->Set({1, 0}, 3.0); @@ -515,7 +543,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { TEST_F(LiteralUtilTest, ReshapeR0) { auto original = Literal::CreateR0(1.7f); - auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); + auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); EXPECT_EQ(*original, *reshape); } @@ -597,24 +625,26 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. - auto mat_dim0minor = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, - layout_r2_dim0minor_); - EXPECT_EQ(mat_dim0minor->s32s_size(), 6); - EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); + auto mat_dim0minor = Literal::CreateR2WithLayout( + {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); + EXPECT_EQ(mat_dim0minor->element_count(), 6); + EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); - EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); + EXPECT_THAT(relaid_mat_to_dim0major->data(), + ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). - auto mat_dim0major = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, - layout_r2_dim0major_); - EXPECT_EQ(mat_dim0major->s32s_size(), 6); - EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); + auto mat_dim0major = Literal::CreateR2WithLayout( + {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); + EXPECT_EQ(mat_dim0major->element_count(), 6); + EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); - EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); + EXPECT_THAT(relaid_mat_to_dim0minor->data(), + ElementsAre(1, 4, 2, 5, 3, 6)); } TEST_F(LiteralUtilTest, TestR3LinearLayout) { @@ -634,27 +664,27 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { auto lit_dim0minor = Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0minor_); - EXPECT_EQ(lit_dim0minor->s32s_size(), 12); + EXPECT_EQ(lit_dim0minor->element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; - EXPECT_THAT(lit_dim0minor->s32s(), + EXPECT_THAT(lit_dim0minor->data(), testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - EXPECT_THAT(relaid_lit_to_dim0major->s32s(), + EXPECT_THAT(relaid_lit_to_dim0major->data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). auto lit_dim0major = Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0major_); - EXPECT_EQ(lit_dim0major->s32s_size(), 12); - EXPECT_THAT(lit_dim0major->s32s(), + EXPECT_EQ(lit_dim0major->element_count(), 12); + EXPECT_THAT(lit_dim0major->data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); - EXPECT_THAT(relaid_lit_to_dim0minor->s32s(), + EXPECT_THAT(relaid_lit_to_dim0minor->data(), testing::ElementsAreArray(expected_dim0minor)); } @@ -687,28 +717,28 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) { } TEST_F(LiteralUtilTest, PopulateR1S64) { - Literal output; + Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); auto expected = Literal::CreateR1({77}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { - Literal output; + Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); auto expected = Literal::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { - Literal output; + Literal output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); auto expected = Literal::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR2C64) { - Literal output; + Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = Literal::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); @@ -716,78 +746,78 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { } TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { - Literal output; + Literal output(ShapeUtil::MakeShape(BF16, {})); bfloat16 h(0.25f); - output.PopulateWithValue(h, {}); + output.PopulateWithValue(h); auto expected = Literal::CreateR0(h); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { - Literal output; + Literal output(ShapeUtil::MakeShape(BF16, {3})); bfloat16 h(0.5f); - output.PopulateWithValue(h, {3}); + output.PopulateWithValue(h); auto expected = Literal::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { - Literal output; + Literal output(ShapeUtil::MakeShape(BF16, {2, 2})); bfloat16 h(2.0f); - output.PopulateWithValue(h, {2, 2}); + output.PopulateWithValue(h); auto expected = Literal::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { - Literal output; - output.PopulateWithValue(2.5f, {}); + Literal output(ShapeUtil::MakeShape(F32, {})); + output.PopulateWithValue(2.5f); auto expected = Literal::CreateR0(2.5f); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { - Literal output; - output.PopulateWithValue(-7, {3}); + Literal output(ShapeUtil::MakeShape(S64, {3})); + output.PopulateWithValue(-7); auto expected = Literal::CreateR1({-7, -7, -7}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { - Literal output; - output.PopulateWithValue(42, {2, 2}); + Literal output(ShapeUtil::MakeShape(U64, {2, 2})); + output.PopulateWithValue(42); auto expected = Literal::CreateR2({{42, 42}, {42, 42}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { - Literal output; - output.PopulateWithValue({4, 2}, {2, 2}); + Literal output(ShapeUtil::MakeShape(C64, {2, 2})); + output.PopulateWithValue({4, 2}); auto expected = Literal::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { - Literal output; + Literal output(ShapeUtil::MakeShape(F16, {})); half h(0.25f); - output.PopulateWithValue(h, {}); + output.PopulateWithValue(h); auto expected = Literal::CreateR0(h); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { - Literal output; + Literal output(ShapeUtil::MakeShape(F16, {3})); half h(0.5f); - output.PopulateWithValue(h, {3}); + output.PopulateWithValue(h); auto expected = Literal::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { - Literal output; + Literal output(ShapeUtil::MakeShape(F16, {2, 2})); half h(2.0f); - output.PopulateWithValue(h, {2, 2}); + output.PopulateWithValue(h); auto expected = Literal::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } @@ -803,7 +833,7 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) { EXPECT_EQ(*output, *expected); } -TEST_F(LiteralUtilTest, Copy) { +TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 dimensions[] = {17, 15, 34, 21}; const int64 layouts[][4] = { {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}}; @@ -826,7 +856,7 @@ TEST_F(LiteralUtilTest, Copy) { const int64 src_base[] = {3, 1, 5, 7}; const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(blank->Copy(*source, src_base, dest_base, copy_size)); + TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); @@ -849,16 +879,16 @@ TEST_F(LiteralUtilTest, Copy) { } } -TEST_F(LiteralUtilTest, CopyScalars) { +TEST_F(LiteralUtilTest, CopyFromScalars) { auto zero = Literal::CreateR0(0); auto nine = Literal::CreateR0(9); - TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {})); + TF_EXPECT_OK(zero->CopyFrom(*nine)); EXPECT_EQ(*zero, *nine); auto vect = Literal::CreateR1({3, 4, 9, 12, 5, 17, 21}); - TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {})); + TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); EXPECT_EQ(zero->Get({}), 17); - TF_EXPECT_OK(vect->Copy(*zero, {}, {4}, {})); + TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); EXPECT_EQ(vect->Get({4}), 17); } @@ -872,7 +902,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = Literal::CreateR1({9}); - TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0})); + TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); EXPECT_EQ(*nine, *const_nine); } @@ -881,18 +911,101 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = Literal::CreateR1({9}); - TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0})); + TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); EXPECT_EQ(*empty, *const_empty); } } +TEST_F(LiteralUtilTest, CopyFromNilShape) { + Literal nil_literal0(ShapeUtil::MakeNil()); + Literal nil_literal1(ShapeUtil::MakeNil()); + // This doesn't actually do any copying, but it should succeed. + TF_ASSERT_OK(nil_literal0.CopyFrom(nil_literal1)); +} + +TEST_F(LiteralUtilTest, CopyFromArrays) { + auto scalar_42 = Literal::CreateR0(42.0); + auto scalar_123 = Literal::CreateR0(123.0); + EXPECT_NE(*scalar_42, *scalar_123); + TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(*scalar_42, *scalar_123); + EXPECT_EQ(scalar_42->Get({}), 123.0f); + + auto matrix_1234 = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_5678 = Literal::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); + EXPECT_NE(*matrix_1234, *matrix_5678); + EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); + TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(*matrix_1234, *matrix_5678); + EXPECT_EQ(matrix_1234->Get({0, 0}), 5.0f); +} + +TEST_F(LiteralUtilTest, CopyFromTuples) { + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal nil_literal(ShapeUtil::MakeNil()); + auto nested_tuple = Literal::MakeTuple( + {matrix.get(), + Literal::MakeTuple({Literal::CreateR0(42).get(), + Literal::CreateR1({23.0, 44.0}).get(), + &nil_literal}) + .get()}); + // Create a tuple the same shape as the inner tuple of nested_tuple but with + // different values.. + auto tuple = Literal::MakeTuple({Literal::CreateR0(-5).get(), + Literal::CreateR1({2.0, 4.0}).get(), + &nil_literal}); + + EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); + EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); + EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); + + // Overwrite the inner tuple element of nested_tuple with the contents of + // 'tuple'. + TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{})); + + // The matrix element should be unchanged. + EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + + // The tuple element should have been copied from 'tuple'. + EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); + EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 2.0); + EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); +} +TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { + auto tuple = Literal::MakeTuple( + {Literal::CreateR0(-2).get(), Literal::CreateR0(4).get()}); + + EXPECT_EQ(tuple->Get({}, {0}), -2); + EXPECT_EQ(tuple->Get({}, {1}), 4); + + // Copy from one element to the other. + TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{0})); + + EXPECT_EQ(tuple->Get({}, {0}), -2); + EXPECT_EQ(tuple->Get({}, {1}), -2); +} + +TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto vector = Literal::CreateR1({5.0, 7.0}); + Status status = matrix->CopyFrom(*vector); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Destination subshape incompatible")); +} + TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); Literal* l1 = m1.get(); - const char* d1 = static_cast(l1->InternalData()); + const char* d1 = reinterpret_cast(l1->data().data()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -901,13 +1014,12 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d1[5], 0); EXPECT_EQ(d1[6], 0); EXPECT_EQ(d1[7], 0); - EXPECT_EQ(l1->InternalData(), l1->MutableInternalData()); half h1(1.0f); half h2(2.0f); auto m2 = Literal::CreateR2({{h1, h2}, {h2, h1}}); Literal* l2 = m2.get(); - const char* d2 = static_cast(l2->InternalData()); + const char* d2 = reinterpret_cast(l2->data().data()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -916,7 +1028,6 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d2[5], 0x40); EXPECT_EQ(d2[6], 0); EXPECT_EQ(d2[7], 0x3C); - EXPECT_EQ(l2->InternalData(), l2->MutableInternalData()); } TEST_F(LiteralUtilTest, Populate) { @@ -941,7 +1052,9 @@ TEST_F(LiteralUtilTest, Populate) { auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return literal->LinearIndex(indexes) + 17; + return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + indexes) + + 17; }; TF_EXPECT_OK(literal->Populate(generator)); @@ -1118,16 +1231,18 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) { for (int len = 0; len < 25; ++len) { p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(len); + LayoutUtil::SetToDefaultLayout(p.mutable_shape()); p.clear_preds(); for (int i = 0; i < len; ++i) { p.add_preds((i % 2) == (len % 2)); } - Literal literal(p); - ASSERT_EQ(len, literal.preds_size()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, + Literal::CreateFromProto(p)); + ASSERT_EQ(len, literal->data().size()); int i = 0; - for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) { - EXPECT_EQ((i % 2) == (len % 2), *it); + for (bool value : literal->data()) { + EXPECT_EQ((i % 2) == (len % 2), value); ++i; } } @@ -1141,8 +1256,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) { auto m = Literal::CreateR2({{h1, h2}, {h2, h1}}); Literal* l = m.get(); EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); - EXPECT_EQ(4, l->f16s().size()); - EXPECT_EQ(4, l->f16s_size()); + EXPECT_EQ(4, l->data().size()); LiteralProto p = l->ToProto(); EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); @@ -1168,17 +1282,12 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { p.mutable_shape()->set_element_type(F16); p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(4); + LayoutUtil::SetToDefaultLayout(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); - - Literal literal(p); - ASSERT_EQ(4, literal.f16s_size()); - ASSERT_EQ(h1, literal.f16s(0)); - ASSERT_EQ(h2, literal.f16s(1)); - ASSERT_EQ(h2, literal.f16s(2)); - ASSERT_EQ(h1, literal.f16s(3)); - - const std::vector& r = literal.f16s(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, + Literal::CreateFromProto(p)); + auto r = literal->data(); ASSERT_EQ(4, r.size()); ASSERT_EQ(h1, r[0]); ASSERT_EQ(h2, r[1]); @@ -1186,24 +1295,402 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } -TEST_F(LiteralUtilTest, Subliterals) { +TEST_F(LiteralUtilTest, LiteralViewTest) { + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + Literal nil(ShapeUtil::MakeNil()); + + EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar); + EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix); + EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple); + EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralView::Create(nil, {}), nil); + + EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar); + EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix); + + EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple); + EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar); + EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix); + EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar); +} + +TEST_F(LiteralUtilTest, MutatingLiteralView) { + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + // Verify that changing the underlying data beneath the view changes the + // data of the view itself. + const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + EXPECT_EQ( + nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 1.0f); + EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, + /*shape_index=*/{0, 0}), + 1.0f); + nested_tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); + EXPECT_EQ( + nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 555.0f); + EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, + /*shape_index=*/{0, 0}), + 555.0f); +} + +TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); - EXPECT_EQ(&scalar->GetSubliteral(/*index=*/{}), scalar.get()); - EXPECT_EQ(&matrix->GetSubliteral(/*index=*/{}), matrix.get()); - EXPECT_EQ(&tuple->GetSubliteral(/*index=*/{}), tuple.get()); - EXPECT_EQ(&nested_tuple->GetSubliteral(/*index=*/{}), nested_tuple.get()); + const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + const auto tuple_view = + LiteralView::Create(nested_tuple_view, /*view_root=*/{0}); + const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1}); + EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); +} + +TEST_F(LiteralUtilTest, LiteralMove) { + std::unique_ptr matrix = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal(std::move(*matrix)); + + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); + EXPECT_EQ(literal.Get({0, 0}), 1.0); + EXPECT_EQ(literal.Get({0, 1}), 2.0); + EXPECT_EQ(literal.Get({1, 0}), 3.0); + EXPECT_EQ(literal.Get({1, 1}), 4.0); +} - EXPECT_EQ(tuple->GetSubliteral(/*index=*/{0}), *scalar); - EXPECT_EQ(tuple->GetSubliteral(/*index=*/{1}), *matrix); +TEST_F(LiteralUtilTest, DecomposeTuple) { + Literal nil_literal(ShapeUtil::MakeNil()); + auto nested_tuple = Literal::MakeTuple( + {Literal::CreateR2({{1, 2}, {3, 4}}).get(), + Literal::MakeTuple({Literal::CreateR0(42).get(), + Literal::CreateR1({23.0, 44.0}).get(), + &nil_literal}) + .get(), + &nil_literal}); + + EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape())); + std::vector elements = nested_tuple->DecomposeTuple(); + EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape())); + + ASSERT_EQ(elements.size(), 3); + + EXPECT_TRUE(ShapeUtil::Compatible(elements[0].shape(), + ShapeUtil::MakeShape(S32, {2, 2}))); + EXPECT_EQ(elements[0].Get({0, 0}), 1); + EXPECT_EQ(elements[0].Get({0, 1}), 2); + EXPECT_EQ(elements[0].Get({1, 0}), 3); + EXPECT_EQ(elements[0].Get({1, 1}), 4); + + EXPECT_TRUE(ShapeUtil::Compatible( + elements[1].shape(), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F64, {2}), + ShapeUtil::MakeNil()}))); + EXPECT_EQ(elements[1].Get({}, /*shape_index=*/{0}), 42); + EXPECT_EQ(elements[1].Get({0}, /*shape_index=*/{1}), 23.0); + EXPECT_EQ(elements[1].Get({1}, /*shape_index=*/{1}), 44.0); + + EXPECT_TRUE(ShapeUtil::Compatible(elements[2].shape(), ShapeUtil::MakeNil())); +} + +TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { + Literal nil_literal(ShapeUtil::MakeNil()); + std::vector elements = nil_literal.DecomposeTuple(); + EXPECT_EQ(elements.size(), 0); +} + +TEST_F(LiteralUtilTest, MoveIntoTuple) { + std::vector elements; + elements.push_back(std::move(*Literal::CreateR0(1.0))); + elements.push_back(std::move(*Literal::CreateR1({4, 8}))); + elements.push_back(std::move( + *Literal::MakeTuple({Literal::CreateR0(42).get(), + Literal::CreateR1({23.0, 44.0}).get()}) + + )); + + Literal literal = Literal::MoveIntoTuple(&elements); + ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3); + + EXPECT_EQ(literal.Get({}, /*shape_index=*/{0}), 1.0); + EXPECT_EQ(literal.Get({0}, /*shape_index=*/{1}), 4); + EXPECT_EQ(literal.Get({1}, /*shape_index=*/{1}), 8); + EXPECT_EQ(literal.Get({}, /*shape_index=*/{2, 0}), 42); + EXPECT_EQ(literal.Get({0}, /*shape_index=*/{2, 1}), 23.0); + EXPECT_EQ(literal.Get({1}, /*shape_index=*/{2, 1}), 44.0); + + for (const Literal& element : elements) { + EXPECT_TRUE(ShapeUtil::IsNil(element.shape())); + } +} + +TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { + Literal literal = Literal::MoveIntoTuple({}); + ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); +} + +TEST_F(LiteralUtilTest, LiteralMoveAssignment) { + Literal literal; + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); + + std::unique_ptr matrix = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + literal = std::move(*matrix); + + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); + EXPECT_EQ(literal.Get({0, 0}), 1.0); + EXPECT_EQ(literal.Get({0, 1}), 2.0); + EXPECT_EQ(literal.Get({1, 0}), 3.0); + EXPECT_EQ(literal.Get({1, 1}), 4.0); +} + +TEST_F(LiteralUtilTest, LiteralViewCopy) { + std::unique_ptr matrix = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + const auto matrix_view = LiteralView::Create(*matrix); + LiteralView matrix_view_copy(matrix_view); + + EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); + EXPECT_EQ(matrix_view_copy.Get({0, 1}), 2.0); + EXPECT_EQ(matrix_view_copy.Get({1, 0}), 3.0); + EXPECT_EQ(matrix_view_copy.Get({1, 1}), 4.0); +} + +TEST_F(LiteralUtilTest, GetSetTuple) { + auto tuple = Literal::MakeTuple( + {Literal::CreateR0(42.0).get(), + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); + EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); + tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); + EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); + + EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), + 3.0); + tuple->Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); + EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), + -4.0); +} + +TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { + // Literals constructed using CreateFromShape should be zero initialized. + std::unique_ptr scalar_f32 = + Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); + EXPECT_EQ(scalar_f32->Get({}), 0.0); + EXPECT_TRUE(scalar_f32->IsAll(0)); + + std::unique_ptr vector_s32 = + Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); + EXPECT_EQ(vector_s32->Get({0}), 0); + EXPECT_EQ(vector_s32->Get({1}), 0); + EXPECT_EQ(vector_s32->Get({2}), 0); + EXPECT_TRUE(vector_s32->IsAll(0)); + + std::unique_ptr tuple = + Literal::CreateFromShape(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + + EXPECT_EQ(tuple->Get({}, {0}), 0.0); + EXPECT_EQ(tuple->Get({0}, {1}), false); + EXPECT_EQ(tuple->Get({1}, {1}), false); + EXPECT_EQ(tuple->Get({0, 0}, {2}), 0); + EXPECT_EQ(tuple->Get({1, 0}, {2}), 0); + EXPECT_EQ(tuple->Get({}, {3}), complex64(0.0f, 0.0f)); +} + +TEST_F(LiteralUtilTest, ProtoRoundTrip) { + // Test serializing then deserializing a Literal through a proto. + auto one_f32 = Literal::CreateR0(1.0); + auto two_f32 = Literal::CreateR0(2.0); + auto vector_int8 = Literal::CreateR1({-128, 0, 2, 4, 7, 56, 127}); + auto vector_c64 = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_bfloat16 = Literal::CreateR1( + {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); + auto vector_half = + Literal::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); + auto matrix_pred = + Literal::CreateR2({{true, false, true}, {false, false, true}}); + auto tuple = Literal::MakeTuple( + {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); + Literal nil_literal(ShapeUtil::MakeNil()); + auto nested_tuple = Literal::MakeTuple( + {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); + + auto to_from_proto = [](const Literal& literal) -> Literal { + return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie()); + }; + + EXPECT_EQ(*one_f32, to_from_proto(*one_f32)); + EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64)); + EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16)); + EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred)); + EXPECT_EQ(*tuple, to_from_proto(*tuple)); + EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple)); + EXPECT_EQ(nil_literal, to_from_proto(nil_literal)); + + EXPECT_NE(*one_f32, *two_f32); + EXPECT_NE(*one_f32, to_from_proto(*two_f32)); +} + +TEST_F(LiteralUtilTest, InvalidProtoNoValues) { + // Proto contains a shape, but no values. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Expected 3 elements in LiteralProto")); +} + +TEST_F(LiteralUtilTest, InvalidProtoNoShape) { + // Proto contains values, but no shape. + LiteralProto proto; + proto.add_preds(false); + proto.add_preds(true); + proto.add_preds(false); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); +} + +TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { + // Proto contains values in wrong container. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + proto.add_preds(false); + proto.add_preds(true); + proto.add_preds(false); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Expected 3 elements in LiteralProto")); +} + +TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { + // Proto contains too few values. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}); + proto.add_f32s(1.0); + proto.add_f32s(2.0); + proto.add_f32s(3.0); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Expected 84 elements in LiteralProto")); +} + +TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { + // Proto contains too many values. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}); + proto.add_s32s(42); + proto.add_s32s(-10); + proto.add_s32s(100); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Expected 2 elements in LiteralProto")); +} + +TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { + // Proto shape missing layout. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}); + LayoutUtil::ClearLayout(proto.mutable_shape()); + proto.add_preds(true); + proto.add_preds(false); + proto.add_preds(true); + proto.add_preds(false); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); +} + +TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { + // Proto has the too few tuple elements. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + LiteralProto* element0 = proto.add_tuple_literals(); + *element0->mutable_shape() = + ShapeUtil::GetTupleElementShape(proto.shape(), 0); + element0->add_preds(false); + element0->add_preds(true); + + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); +} + +TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { + // Proto has the too many tuple elements. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + LiteralProto* element0 = proto.add_tuple_literals(); + *element0->mutable_shape() = + ShapeUtil::GetTupleElementShape(proto.shape(), 0); + element0->add_preds(false); + element0->add_preds(true); + LiteralProto* element1 = proto.add_tuple_literals(); + *element1->mutable_shape() = + ShapeUtil::GetTupleElementShape(proto.shape(), 1); + element1->add_f32s(42.0); + LiteralProto* element2 = proto.add_tuple_literals(); + *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}); + element2->add_f32s(123.0); + + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); +} + +TEST_F(LiteralUtilTest, SortSparseElements) { + auto literal = + Literal::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); + literal->AppendSparseElement({2, 3, 4}, 2.0); + literal->AppendSparseElement({3, 4, 5}, 3.0); + literal->AppendSparseElement({1, 2, 3}, 1.0); + literal->SortSparseElements(); + ASSERT_EQ(literal->ToString(false), + "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); +} - EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0}), *tuple); - EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 0}), *scalar); - EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 1}), *matrix); - EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{1}), *scalar); +TEST_F(LiteralUtilTest, GetSparseElementAsString) { + std::vector dimensions = {10, 10, 10}; + SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); + + ASSERT_EQ( + Literal::CreateSparse(dimensions, indices, {true, false, true}) + ->GetSparseElementAsString(1), + "false"); + ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {1, 2, 3}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(int64{2})); + ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {1.0, 2.0, 3.0}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(double{2.0})); + ASSERT_EQ(Literal::CreateSparse(dimensions, indices, + {half{1.0}, half{2.0}, half{3.0}}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(half{2.0})); + ASSERT_EQ( + Literal::CreateSparse( + dimensions, indices, + std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } } // namespace diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 51d0d5f86f00c539951e8e2baa6296337a5a21e9..50659c12405f2a29c69b03b3c7de5bd6cb6af9c2 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -60,6 +60,12 @@ bool ContainsKey(const Collection& collection, const Key& key) { return collection.find(key) != collection.end(); } +// Inserts `value` into `set`. Dies if it was already present. +template +void InsertOrDie(Set* const set, const typename Set::value_type& value) { + CHECK(set->insert(value).second) << "duplicate value: " << value; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 70e0f5a74711c8ceef1b6d4225141aa1cc9c6219..857aae0a7982a57bb3057a6f267f5f033a0fdde4 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -44,11 +44,11 @@ StatusOr> PackedLiteralReader::Read( VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) << " layout: " << (layout == nullptr ? "" : layout->ShortDebugString()); - auto result = MakeUnique(); - *result->mutable_shape() = shape; + Shape literal_shape = shape; if (layout != nullptr) { - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(*layout, shape)); - *result->mutable_shape()->mutable_layout() = *layout; + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(*layout, literal_shape)); + *literal_shape.mutable_layout() = *layout; } if (shape.element_type() != F32) { @@ -57,10 +57,12 @@ StatusOr> PackedLiteralReader::Read( PrimitiveType_Name(shape.element_type()).c_str()); } + auto result = MakeUnique(literal_shape); + result->PopulateWithValue(std::numeric_limits::quiet_NaN()); + int64 elements = ShapeUtil::ElementsIn(shape); - result->Resize(elements, std::numeric_limits::quiet_NaN()); - std::vector* field = result->mutable_f32s(); - char* data = tensorflow::bit_cast(field->data()); + tensorflow::gtl::ArraySlice field = result->data(); + char* data = tensorflow::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); tensorflow::StringPiece sp; auto s = file_->Read(offset_, bytes, &sp, data); diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 19c6a138885c61f1304bfae3d8bb5d958a1bb5bc..cb4583d198b454be1432134a9f6a77dbbbe5bdd8 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -26,6 +26,13 @@ limitations under the License. namespace xla { namespace primitive_util { +// The number of exponent bits in a BF16 value. +const int kBFloat16ExponentBits = 8; + +// The number of mantissa bits in a BF16 value. There is an implicit leading +// 1, so there is an implicit additional bit of precision. +const int kBFloat16MantissaBits = 7; + // Returns the XLA primitive type (eg, F32) corresponding to the given // template parameter native type (eg, float). template diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a8ca0e3ea0115d412e96ebacb320cc0dde061dff --- /dev/null +++ b/tensorflow/compiler/xla/python/BUILD @@ -0,0 +1,85 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") + +py_library( + name = "xla_client", + srcs = ["xla_client.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":pywrap_xla", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) + +py_test( + name = "xla_client_test", + srcs = ["xla_client_test.py"], + main = "xla_client_test.py", + srcs_version = "PY2AND3", + deps = [ + ":xla_client", + "//tensorflow/python:platform_test", + ], +) + +cc_library( + name = "numpy_bridge", + srcs = ["numpy_bridge.cc"], + hdrs = ["numpy_bridge.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/python:numpy_lib", + ], +) + +cc_library( + name = "local_computation_builder", + srcs = ["local_computation_builder.cc"], + hdrs = ["local_computation_builder.h"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", + ], +) + +tf_py_wrap_cc( + name = "pywrap_xla", + srcs = ["xla.i"], + swig_includes = [ + "local_computation_builder.i", + ], + deps = [ + ":local_computation_builder", + ":numpy_bridge", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:cpu_plugin", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/python/__init__.py b/tensorflow/compiler/xla/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..37f1eada2bc9f5ef72d99a835a17b4e78a354ae6 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -0,0 +1,537 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/default/thread_annotations.h" + +namespace xla { + +namespace swig { + +// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of +// device handles instead of needing to set the number of replicas at XLA +// service initialization time. +tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); +int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; +LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; + +Status InitializeReplicaCount(int replica_count) { + if (replica_count < 1) { + return InvalidArgument("Replica count must be >= 1; got %d.", + replica_count); + } + tensorflow::mutex_lock lock(g_local_client_mutex); + if (g_local_client != nullptr) { + return FailedPrecondition( + "Attempted to set the replica count to %d, but a local XLA service was " + "previously created with a replica count of %d.", + replica_count, g_replica_count); + } + g_replica_count = replica_count; + return Status::OK(); +} + +int GetReplicaCount() { + tensorflow::mutex_lock lock(g_local_client_mutex); + return g_replica_count; +} + +LocalClient* GetOrCreateLocalClient() { + tensorflow::mutex_lock lock(g_local_client_mutex); + if (g_local_client != nullptr) { + return g_local_client; + } + LocalClientOptions options; + options.set_number_of_replicas(g_replica_count); + g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); + CHECK(g_local_client != nullptr); + return g_local_client; +} + +Status TransferToInfeedLocal(const Literal& literal) { + VLOG(1) << "Infeeding literal without replica number; shape: " + << literal.shape(); + LocalClient* client = GetOrCreateLocalClient(); + return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); +} + +Status TransferToInfeedLocalReplica(const Literal& literal, + int replica_number) { + VLOG(1) << "Infeeding shape " << literal.shape() + << " to replica number: " << replica_number; + LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(int device_ordinal, + client->ReplicaNumberToDeviceOrdinal(replica_number)); + return client->TransferToInfeedLocal(literal, device_ordinal); +} + +StatusOr> TransferFromOutfeedLocalReplica( + const Shape& shape, int replica_number) { + VLOG(1) << "Outfeeding literal from replica number: " << replica_number + << " shape: " << shape; + LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(int device_ordinal, + client->ReplicaNumberToDeviceOrdinal(replica_number)); + return client->TransferFromOutfeedLocal(shape, device_ordinal); +} + +LocalShapedBuffer::LocalShapedBuffer( + std::unique_ptr shaped_buffer) + : shaped_buffer_(std::move(shaped_buffer)) {} + +const std::unique_ptr& LocalShapedBuffer::shaped_buffer() + const { + return shaped_buffer_; +} + +/* static */ +LocalShapedBuffer* LocalShapedBuffer::FromLiteral(const Literal& argument) { + LocalClient* client = GetOrCreateLocalClient(); + std::unique_ptr buf = + client + ->LiteralToShapedBuffer(argument, + /*device_ordinal=*/0, + client->backend().memory_allocator()) + .ConsumeValueOrDie(); + return new LocalShapedBuffer(std::move(buf)); +} + +std::unique_ptr LocalShapedBuffer::ToLiteral() const { + LocalClient* client = GetOrCreateLocalClient(); + return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie(); +} + +CompiledLocalComputation::CompiledLocalComputation( + std::unique_ptr executable) + : executable_(std::move(executable)) {} + +StatusOr> CompiledLocalComputation::Execute( + const std::vector& arguments) { + LocalClient* client = GetOrCreateLocalClient(); + + VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; + + // Each replica populates a StatusOr result, but only replica zero actually + // retrieves its literal value. + std::vector>> results(GetReplicaCount()); + { + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", + GetReplicaCount()); + + for (int replica = 0; replica < GetReplicaCount(); ++replica) { + pool.Schedule([this, client, replica, &arguments, &results] { + StatusOr device_ordinal_status = + client->ReplicaNumberToDeviceOrdinal(replica); + if (!device_ordinal_status.ok()) { + results[replica] = device_ordinal_status.status(); + return; + } + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " + << device_ordinal; + // Transfer arguments in + std::vector> scoped_buffers; + scoped_buffers.reserve(arguments.size()); + for (const Literal& argument : arguments) { + StatusOr> pushed = + client->LiteralToShapedBuffer( + argument, device_ordinal, + client->backend().memory_allocator()); + if (!pushed.ok()) { + results[replica] = pushed.status(); + return; + } + scoped_buffers.push_back(std::move(pushed).ValueOrDie()); + } + + // Execute + std::vector argument_buffers; + argument_buffers.reserve(scoped_buffers.size()); + for (auto& buffer : scoped_buffers) { + argument_buffers.push_back(buffer.get()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_inter_op_thread_pool( + client->backend().inter_op_thread_pool()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr> result_buffer_status = + executable_->Run(argument_buffers, options); + if (!result_buffer_status.ok()) { + results[replica] = result_buffer_status.status(); + return; + } + + // Transfer result out + results[replica] = + client->ShapedBufferToLiteral(*result_buffer_status.ValueOrDie()); + }); + } + } + + for (int replica = 0; replica < GetReplicaCount(); ++replica) { + const auto& statusor = results[replica]; + if (!statusor.ok()) { + return InternalError( + "Failed running replica %d (other replicas may have failed as well): " + "%s.", + replica, statusor.status().ToString().c_str()); + } + } + + return std::move(results[0]); +} + +LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( + tensorflow::gtl::ArraySlice argument_handles) { + LocalClient* client = GetOrCreateLocalClient(); + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + argument_buffers.push_back(handle->shaped_buffer().get()); + } + + // Execute + ExecutableRunOptions options; + options.set_allocator(client->backend().memory_allocator()); + options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + std::unique_ptr result_buffer = + executable_->Run(argument_buffers, options).ConsumeValueOrDie(); + + return new LocalShapedBuffer(std::move(result_buffer)); +} + +LocalComputation::LocalComputation(Computation computation) + : computation_(std::move(computation)) {} + +StatusOr LocalComputation::Compile( + const std::vector& argument_shapes) { + std::vector argument_shape_pointers; + argument_shape_pointers.reserve(argument_shapes.size()); + for (auto& argument_shape : argument_shapes) { + argument_shape_pointers.push_back(&argument_shape); + } + + LocalClient* client = GetOrCreateLocalClient(); + ExecutableBuildOptions options; + TF_ASSIGN_OR_RETURN( + auto local_executable, + client->Compile(computation_, argument_shape_pointers, options)); + return new CompiledLocalComputation(std::move(local_executable)); +} + +const Computation& LocalComputation::computation() const { + return computation_; +} + +LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) + : builder_(GetOrCreateLocalClient(), computation_name) {} + +void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { + builder_.SetOpMetadata(metadata); +} + +void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } + +StatusOr LocalComputationBuilder::Build() { + TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + return new LocalComputation(std::move(computation)); +} + +ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { + return builder_.Parameter(parameter_number, shape, name); +} + +std::unique_ptr LocalComputationBuilder::GetShape( + const ComputationDataHandle& operand) { + return builder_.GetShape(operand).ConsumeValueOrDie(); +} + +ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { + return builder_.Infeed(shape); +} + +void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, + const Shape& shape, + const string& outfeed_config) { + builder_.Outfeed(operand, shape, outfeed_config); +} + +ComputationDataHandle LocalComputationBuilder::ConstantLiteral( + const Literal& literal) { + return builder_.ConstantLiteral(literal); +} + +ComputationDataHandle LocalComputationBuilder::Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice broadcast_sizes) { + return builder_.Broadcast(operand, broadcast_sizes); +} + +ComputationDataHandle LocalComputationBuilder::Pad( + const ComputationDataHandle& operand, + const ComputationDataHandle& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand, padding_value, padding_config); +} + +ComputationDataHandle LocalComputationBuilder::Reshape( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return builder_.Reshape(operand, dimensions, new_sizes); +} + +ComputationDataHandle LocalComputationBuilder::Collapse( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions) { + return builder_.Collapse(operand, dimensions); +} + +ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( + const ComputationDataHandle& operand) { + return builder_.CrossReplicaSum(operand); +} + +ComputationDataHandle LocalComputationBuilder::Slice( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return builder_.Slice(operand, start_indices, limit_indices, strides); +} + +ComputationDataHandle LocalComputationBuilder::DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + return builder_.DynamicSlice(operand, start_indices, slice_sizes); +} + +ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices) { + return builder_.DynamicUpdateSlice(operand, update, start_indices); +} + +ComputationDataHandle LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return builder_.ConcatInDim(operands, dimension); +} + +ComputationDataHandle +LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const ComputationDataHandle& operand, const LocalComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ComputationDataHandle& source, + const ComputationDataHandle& init_value, const LocalComputation& scatter) { + return builder_.SelectAndScatterWithGeneralPadding( + operand, select.computation(), window_dimensions, window_strides, padding, + source, init_value, scatter.computation()); +} + +ComputationDataHandle LocalComputationBuilder::Select( + const ComputationDataHandle& pred, const ComputationDataHandle& on_true, + const ComputationDataHandle& on_false) { + return builder_.Select(pred, on_true, on_false); +} + +ComputationDataHandle LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + return builder_.Tuple(elements); +} + +ComputationDataHandle LocalComputationBuilder::GetTupleElement( + const ComputationDataHandle& tuple_data, int64 index) { + return builder_.GetTupleElement(tuple_data, index); +} + +ComputationDataHandle LocalComputationBuilder::Dot( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { + return builder_.Dot(lhs, rhs); +} + +ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers); +} + +ComputationDataHandle LocalComputationBuilder::ConvertElementType( + const ComputationDataHandle& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand, new_element_type); +} + +ComputationDataHandle LocalComputationBuilder::Call( + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands) { + return builder_.Call(local_computation.computation(), operands); +} + +ComputationDataHandle LocalComputationBuilder::Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand, permutation); +} + +ComputationDataHandle LocalComputationBuilder::Rev( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions) { + return builder_.Rev(operand, dimensions); +} + +ComputationDataHandle LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return builder_.Map(operands, local_computation.computation(), dimensions, + static_operands); +} + +ComputationDataHandle LocalComputationBuilder::Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return builder_.Reduce(operand, init_value, local_computation.computation(), + dimensions_to_reduce); +} + +ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return builder_.ReduceWindowWithGeneralPadding( + operand, init_value, local_computation.computation(), window_dimensions, + window_strides, padding); +} + +ComputationDataHandle LocalComputationBuilder::RngNormal( + const ComputationDataHandle& mu, const ComputationDataHandle& sigma, + const Shape& shape) { + return builder_.RngNormal(mu, sigma, shape); +} + +ComputationDataHandle LocalComputationBuilder::RngUniform( + const ComputationDataHandle& a, const ComputationDataHandle& b, + const Shape& shape) { + return builder_.RngUniform(a, b, shape); +} + +ComputationDataHandle LocalComputationBuilder::While( + const LocalComputation& condition, const LocalComputation& body, + const ComputationDataHandle& init) { + return builder_.While(condition.computation(), body.computation(), init); +} + +#define _FORWARD(method_name, return_sig, args_sig, args) \ + return_sig LocalComputationBuilder::method_name args_sig { \ + return builder_.method_name args; \ + } + +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, ComputationDataHandle, \ + (const ComputationDataHandle& operand), (operand)) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs, rhs, broadcast_dimensions)) + +_FORWARD_BINOP(Eq) +_FORWARD_BINOP(Ne) +_FORWARD_BINOP(Ge) +_FORWARD_BINOP(Gt) +_FORWARD_BINOP(Lt) +_FORWARD_BINOP(Le) +_FORWARD_BINOP(Add) +_FORWARD_BINOP(Sub) +_FORWARD_BINOP(Mul) +_FORWARD_BINOP(Div) +_FORWARD_BINOP(Rem) +_FORWARD_BINOP(Max) +_FORWARD_BINOP(Min) +_FORWARD_BINOP(And) +_FORWARD_BINOP(Or) +_FORWARD_UNOP(Not) +_FORWARD_UNOP(Abs) +_FORWARD_UNOP(Exp) +_FORWARD_UNOP(Floor) +_FORWARD_UNOP(Ceil) +_FORWARD_UNOP(Log) +_FORWARD_UNOP(Sign) +_FORWARD_UNOP(Cos) +_FORWARD_UNOP(Sin) +_FORWARD_UNOP(Tanh) +_FORWARD_UNOP(SqrtF32) +_FORWARD_UNOP(SquareF32) +_FORWARD_BINOP(Pow) +_FORWARD_UNOP(IsFinite) +_FORWARD_UNOP(ReciprocalF32) +_FORWARD_UNOP(Neg) +_FORWARD_UNOP(Sort) + +#undef _FORWARD +#undef _FORWARD_UNOP +#undef _FORWARD_BINOP + +void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { + delete local_shaped_buffer; +} + +void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) { + delete computation; +} + +void DeleteLocalComputation(LocalComputation* computation) { + delete computation; +} + +} // namespace swig + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..e5503cd52fa60eff30eea38c83aafe0f0ff1efc8 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -0,0 +1,305 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +namespace swig { + +// Initializes the number of replicas that XLA will be initialized with (when +// first obtaining a handle to the local XLA service). If this is called after +// the handle to the local XLA service has been established, then an error is +// returned. +Status InitializeReplicaCount(int replica_count); + +// Returns the replica count that is currently set, regardless of whether the +// local XLA service has been instantiated yet or not. +int GetReplicaCount(); + +// Wraps the local client's infeed-transfer function. +// +// The default device ordinal (0) is used. +Status TransferToInfeedLocal(const Literal& literal); + +// Transfers the given literal to the infeed of the given replica. +// +// The replica number is resolved to an appropriate device ordinal. +Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); + +// Transfers a literal of the given shape from the outfeed of the given replica. +// +// The replica number is resolved to an appropriate device ordinal. +StatusOr > TransferFromOutfeedLocalReplica( + const Shape& shape, int replica_number); + +// Wraps a ScopedShapedBuffer produced by copying a literal "to +// device," i.e. copying a literal to a scoped buffer via the local +// client. +class LocalShapedBuffer { + public: + static LocalShapedBuffer* FromLiteral(const Literal& argument); + LocalShapedBuffer(std::unique_ptr shaped_buffer); + const std::unique_ptr& shaped_buffer() const; + std::unique_ptr ToLiteral() const; + + private: + std::unique_ptr shaped_buffer_; +}; + +// Wraps a LocalExecutable produced by compiling a +// LocalComputation. The Execute method forwards to that of the +// underlying LocalExecutable, and additionally handles tranferring +// arguments and return values in and back out of the client library's +// local client. This class is intended to be made available to Python +// via SWIG. +class CompiledLocalComputation { + public: + CompiledLocalComputation(std::unique_ptr executable); + StatusOr > Execute( + const std::vector& arguments); + LocalShapedBuffer* ExecuteWithShapedBuffers( + tensorflow::gtl::ArraySlice argument_handles); + + private: + std::unique_ptr executable_; +}; + +// Wraps a Computation produced by a LocalComputationBuilder. The +// Compile method compiles the computation to a (local) executable via +// the client library's local client. This class is intended to be +// made available to Python via SWIG. +class LocalComputation { + public: + LocalComputation(Computation computation); + StatusOr Compile( + const std::vector& argument_shapes); + const Computation& computation() const; + + private: + Computation computation_; +}; + +// Wraps the ComputationBuilder API in order to: +// - Support consumption by SWIG in order to be made available to +// Python. +// - Set up the underlying builder to use the client library's +// LocalClient. +// - Wrap Computations in LocalComputations for Python access. +// - Correspondingly unwrap incoming LocalComputations. +class LocalComputationBuilder { + public: + LocalComputationBuilder(const string& computation_name); + + void SetOpMetadata(const OpMetadata& metadata); + void ClearOpMetadata(); + + // Returns an owned LocalComputation to the caller on success. + StatusOr Build(); + + ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + std::unique_ptr GetShape(const ComputationDataHandle& operand); + + ComputationDataHandle Infeed(const Shape& shape); + + void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + const string& outfeed_config); + + ComputationDataHandle ConstantLiteral(const Literal& literal); + + ComputationDataHandle Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + ComputationDataHandle Pad(const ComputationDataHandle& operand, + const ComputationDataHandle& padding_value, + const PaddingConfig& padding_config); + + ComputationDataHandle Reshape(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + ComputationDataHandle Collapse(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions); + + ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + + ComputationDataHandle Slice(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + ComputationDataHandle DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + ComputationDataHandle DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices); + + ComputationDataHandle ConcatInDim( + tensorflow::gtl::ArraySlice operands, + int64 dimension); + + ComputationDataHandle SelectAndScatterWithGeneralPadding( + const ComputationDataHandle& operand, const LocalComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice > padding, + const ComputationDataHandle& source, + const ComputationDataHandle& init_value, const LocalComputation& scatter); + + ComputationDataHandle Select(const ComputationDataHandle& pred, + const ComputationDataHandle& on_true, + const ComputationDataHandle& on_false); + + ComputationDataHandle Tuple( + tensorflow::gtl::ArraySlice elements); + + ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, + int64 index); + + ComputationDataHandle Dot(const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs); + + ComputationDataHandle ConvGeneralDilated( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice > padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + + ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, + PrimitiveType new_element_type); + + ComputationDataHandle Call( + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); + + ComputationDataHandle Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice permutation); + + ComputationDataHandle Rev(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions); + + ComputationDataHandle Map( + tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); + + ComputationDataHandle Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + + ComputationDataHandle ReduceWindowWithGeneralPadding( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice > padding); + + ComputationDataHandle RngNormal(const ComputationDataHandle& mu, + const ComputationDataHandle& sigma, + const Shape& shape); + + ComputationDataHandle RngUniform(const ComputationDataHandle& a, + const ComputationDataHandle& b, + const Shape& shape); + + ComputationDataHandle While(const LocalComputation& condition, + const LocalComputation& body, + const ComputationDataHandle& init); + +#define _FORWARD(method_name, return_sig, args_sig) \ + return_sig method_name args_sig; + +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, ComputationDataHandle, \ + (const ComputationDataHandle& operand)) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) + + _FORWARD_BINOP(Eq) + _FORWARD_BINOP(Ne) + _FORWARD_BINOP(Ge) + _FORWARD_BINOP(Gt) + _FORWARD_BINOP(Lt) + _FORWARD_BINOP(Le) + _FORWARD_BINOP(Add) + _FORWARD_BINOP(Sub) + _FORWARD_BINOP(Mul) + _FORWARD_BINOP(Div) + _FORWARD_BINOP(Rem) + _FORWARD_BINOP(Max) + _FORWARD_BINOP(Min) + _FORWARD_BINOP(And) + _FORWARD_BINOP(Or) + _FORWARD_UNOP(Not) + _FORWARD_UNOP(Abs) + _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Floor) + _FORWARD_UNOP(Ceil) + _FORWARD_UNOP(Log) + _FORWARD_UNOP(Sign) + _FORWARD_UNOP(Cos) + _FORWARD_UNOP(Sin) + _FORWARD_UNOP(Tanh) + _FORWARD_UNOP(SqrtF32) + _FORWARD_UNOP(SquareF32) + _FORWARD_BINOP(Pow) + _FORWARD_UNOP(IsFinite) + _FORWARD_UNOP(ReciprocalF32) + _FORWARD_UNOP(Neg) + _FORWARD_UNOP(Sort) + +#undef _FORWARD +#undef _FORWARD_UNOP +#undef _FORWARD_BINOP + + private: + ComputationBuilder builder_; +}; + +// Functions for freeing resources from the Python side. +void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); +void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); +void DeleteLocalComputation(LocalComputation* computation); + +} // namespace swig + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i new file mode 100644 index 0000000000000000000000000000000000000000..31789259609714e7d20247eec072e05a181715e6 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -0,0 +1,719 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// SWIG typemaps and declarations for building, compiling, and +// executing XLA computations, wrapping most of what is declared in +// local_computation_builder.h. +// +// The typemaps below implement/assert the following correspondences +// (with elaborations below): +// +// C++ Python +// -------------------------------------+--------------------------------------- +// ComputationDataHandle <-> int +// ArraySlice <- sequence of int +// ArraySlice <- sequence of int +// Literal <-> (nested tuple of) numpy ndarray +// std::vector <- sequence of (nested tuple of) ndarray +// Shape <-> pair holding (dtype, dimensions) +// std::vector <- sequence of shape information pairs +// PrimitiveType <- int +// ArraySlice> <- sequence of int pairs +// PaddingConfig proto <- corresponding Python proto +// ConvolutionDimensionNumbers proto <- corresponding Python proto +// +// Arrows indicate whether a conversion only ever occurs in one +// direction, or whether it is maintained bidirectionally. +// +// The Python objects corresponding to C++ Literals have the type: +// +// T = ndarray | (T, ...) +// +// where a terminal numpy ndarray translates to a Literal with a +// non-tuple Shape, an XLA primitive element type corresponding to the +// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates +// to a tuple-shaped Literal whose tuple components are translated +// recursively. For example, if x is a numpy ndarray in Python, with +// shape (2, 3) and dtype of dtype('float32'), then x translates to a +// Literal with rank 2, dimension 2 and 3, and XLA primitive type +// F32. Meanwhile, +// +// (x, (x, x), (x,)), +// +// translates to a tuple-shaped XLA Literal, whose component subshapes +// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. +// +// The Python objects corresponding to C++ Shapes have the type: +// +// T = (dtype, S) +// S = DIMENSIONS | TUPLE_SHAPES +// DIMENSIONS = (int, ...) +// TUPLE_SHAPES = (T, ...) +// +// In the pair described by the T rule, the terminal dtype determines +// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is +// dtype('O'), numpy's object dtype, the structure represents a tuple +// shape and the expansion of the non-terminal S is +// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type +// and S expands into DIMENSIONS giving dimension sizes. For example: +// +// (dtype('float32'), (3, 5, 7)) +// +// describes a 3x5x7 array of F32s, and +// +// (dtype('O'), ((dtype('float32'), (2, 3)), +// (dtype('float64'), (4, 5)))) +// +// describes a tuple shape with two subshapes: the first a 2x3 F32, +// and the other a 4x5 F64. +// +// The Python int corresponding to a PrimitiveType enum must be valid +// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). +// +// The SWIG object wrappers generated by this file are not intended +// for end use, but rather for internal use in the Python XLA client, +// xla_client.py. +// +// One central reason for the Python-side indirection is that the +// Python-side objects produced by the typemaps in this file are +// further packaged up by xla_client before being passed on. For +// instance, xla_client wraps the long produced for a C++ +// ComputationDataHandle in a Python ComputationDataHandle proto, +// rather than exposing a raw long outside of the client. Similarly, +// the Python pair produced for a C++ Shape is further wrapped in a +// Python class (xla_client.Shape) so as not to expose the raw pair +// externally. +// +// Other SWIG object wrappers (e.g. of LocalComputation) are further +// wrapped by xla_client in order to set up a custom destructor that +// triggers memory deallocation on the C++ side. + +%module(threads="1") local_computation_builder + +// Keep the GIL except where explicitly specified. +%nothread; + +%include "tensorflow/python/platform/base.i" + +%{ +// Must be included first +#include "tensorflow/python/lib/core/numpy.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/python/local_computation_builder.h" + +using namespace xla; +using namespace xla::swig; + +namespace xla { +namespace swig { + +bool GetIntAttr(PyObject* o, const char* field, int64* result) { + PyObject* fo = PyObject_GetAttrString(o, field); + if (!fo) { + return false; + } + const int64 value = numpy::PyIntOrPyLongToLong(fo); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(fo); + return false; + } + Py_DECREF(fo); + *result = value; + return true; +} + +} +} +%} + +// Required to use PyArray_* functions. +%init %{ +tensorflow::ImportNumpy(); +%} + +// ComputationDataHandle + +%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { + const int64 handle = numpy::PyIntOrPyLongToLong($input); + if (handle == -1 && PyErr_Occurred()) { + return NULL; + } + temp.set_handle(handle); + $1 = &temp; +} + +%typemap(out) ComputationDataHandle { + $result = numpy::LongToPyIntOrPyLong($1.handle()); +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::CompiledLocalComputation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalComputation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + +%typemap(out) Status { + if (!$1.ok()) { + PyErr_SetString( + PyExc_RuntimeError, $1.ToString().c_str()); + return NULL; + } + $result = Py_None; +} + +// ArraySlice + +%typemap(in) tensorflow::gtl::ArraySlice + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + Py_DECREF(o); + return NULL; + } + temps[i] = numpy::PyIntOrPyLongToLong(py_int); + if (temps[i] == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + return NULL; + } + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// ComputationDataHandle + +%typemap(in) tensorflow::gtl::ArraySlice + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + return NULL; + } + const int64 handle = numpy::PyIntOrPyLongToLong(py_int); + if (handle == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + return NULL; + } + temps[i].set_handle(handle); + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// LocalShapedBuffer* + +%typemap(in) tensorflow::gtl::ArraySlice + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + LocalShapedBuffer* lsbp; + if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), + SWIG_POINTER_EXCEPTION)) == -1) { + return NULL; + } + temps.push_back(lsbp); + Py_DECREF(o); + } + $1 = temps; +} + +// Literal + +%typemap(in) const Literal& (StatusOr< std::unique_ptr > literal_status) { + literal_status = numpy::XlaLiteralFromPyObject($input); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + return NULL; + } + $1 = literal_status.ValueOrDie().get(); +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyObjectFromXlaLiteral(*$1); +} + +%typemap(out) StatusOr< std::unique_ptr > { + if (!$1.ok()) { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } + $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr< std::unique_ptr > literal_status = numpy::XlaLiteralFromPyObject(o); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + Py_DECREF(o); + return NULL; + } + temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); + Py_DECREF(o); + } + $1 = &temps; +} + +// OpMetadata + +%typemap(in) const OpMetadata& (OpMetadata temp) { + StatusOr statusor = numpy::OpMetadataFromPyObject($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +// Shape + +%typemap(in) const Shape& (Shape temp) { + Status shape_status = numpy::CheckPyShapeInfo($input); + if (!shape_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str()); + return NULL; + } + temp = numpy::XlaShapeFromPyShapeInfo($input); + $1 = &temp; +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyShapeInfoFromXlaShape(*$1); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + Status shape_status = numpy::CheckPyShapeInfo(o); + if (!shape_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str()); + Py_DECREF(o); + return NULL; + } + temps.push_back(numpy::XlaShapeFromPyShapeInfo(o)); + Py_DECREF(o); + } + $1 = &temps; +} + +// PrimitiveType + +%typemap(in) PrimitiveType { + PyObject* py_int = numpy::PyNumberToPyInt($input); + if (!py_int) { + PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); + return NULL; + } + const long value = numpy::PyIntOrPyLongToLong(py_int); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + return NULL; + } + if (!PrimitiveType_IsValid(value)) { + PyErr_SetString( + PyExc_TypeError, "Argument not valid for PrimitiveType enum"); + Py_DECREF(py_int); + return NULL; + } + $1 = static_cast(value); +} + +// ArraySlice> + +%typemap(in) tensorflow::gtl::ArraySlice > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (!o) { + return NULL; + } + PyObject* first = PyTuple_GetItem(o, 0); + if (!first) { + Py_DECREF(o); + return NULL; + } + PyObject* first_pyint = numpy::PyNumberToPyInt(first); + if (!first_pyint) { + PyErr_SetString( + PyExc_TypeError, + "First pair item cannot be converted to int"); + Py_DECREF(o); + return NULL; + } + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + Py_DECREF(o); + Py_DECREF(first_pyint); + return NULL; + } + PyObject* second_pyint = numpy::PyNumberToPyInt(second); + if (!second_pyint) { + PyErr_SetString( + PyExc_TypeError, + "Second pair item cannot be converted to int"); + Py_DECREF(o); + Py_DECREF(first_pyint); + return NULL; + } + const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); + if (first_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + return NULL; + } + const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); + if (second_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + return NULL; + } + temps.push_back(std::make_pair(first_value, second_value)); + Py_DECREF(o); + } + $1 = temps; +} + +// PaddingConfig + +%typemap(in) const PaddingConfig& + (PaddingConfig padding_config) { + PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); + if (!dimensions) { + return NULL; + } + + int length = PySequence_Size(dimensions); + if (length == -1) { + Py_DECREF(dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(dimensions, i); + if (!item) { + Py_DECREF(dimensions); + return NULL; + } + int64 edge_padding_low, edge_padding_high, interior_padding; + if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) + || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) + || !GetIntAttr(item, "interior_padding", &interior_padding)) { + Py_DECREF(item); + Py_DECREF(dimensions); + return NULL; + } + Py_DECREF(item); + + PaddingConfig::PaddingConfigDimension* dimension = + padding_config.add_dimensions(); + dimension->set_edge_padding_low(edge_padding_low); + dimension->set_edge_padding_high(edge_padding_high); + dimension->set_interior_padding(interior_padding); + } + Py_DECREF(dimensions); + + $1 = &padding_config; +} + +// ConvolutionDimensionNumbers + +%typemap(in) const ConvolutionDimensionNumbers& + (ConvolutionDimensionNumbers dimension_numbers) { + int64 value; + + if (!GetIntAttr($input, "input_batch_dimension", &value)) { + return NULL; + } + dimension_numbers.set_input_batch_dimension(value); + + if (!GetIntAttr($input, "input_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_input_feature_dimension(value); + + if (!GetIntAttr($input, "output_batch_dimension", &value)) { + return NULL; + } + dimension_numbers.set_output_batch_dimension(value); + + if (!GetIntAttr($input, "output_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_kernel_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_kernel_input_feature_dimension(value); + + PyObject* o; + int length; + + o = PyObject_GetAttrString($input, "input_spatial_dimensions"); + if (!o) { + return NULL; + } + length = PySequence_Size(o); + if (length == -1) { + Py_DECREF(o); + return NULL; + } + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(o, i); + if (!item) { + Py_DECREF(o); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(o); + return NULL; + } + dimension_numbers.add_input_spatial_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(o); + + o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); + if (!o) { + return NULL; + } + length = PySequence_Size(o); + if (length == -1) { + Py_DECREF(o); + return NULL; + } + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(o, i); + if (!item) { + Py_DECREF(o); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(o); + return NULL; + } + dimension_numbers.add_kernel_spatial_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(o); + + o = PyObject_GetAttrString($input, "output_spatial_dimensions"); + if (!o) { + return NULL; + } + length = PySequence_Size(o); + if (length == -1) { + Py_DECREF(o); + return NULL; + } + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(o, i); + if (!item) { + Py_DECREF(o); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(o); + return NULL; + } + dimension_numbers.add_output_spatial_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(o); + + $1 = &dimension_numbers; +} + +%ignoreall +%unignore xla; +%unignore xla::swig; +%unignore xla::swig::InitializeReplicaCount; +%unignore xla::swig::GetReplicaCount; +%unignore xla::swig::TransferToInfeedLocal; +%unignore xla::swig::TransferToInfeedLocalReplica; +%unignore xla::swig::TransferFromOutfeedLocalReplica; +%unignore xla::swig::LocalShapedBuffer; +%unignore xla::swig::LocalShapedBuffer::FromLiteral; +%unignore xla::swig::LocalShapedBuffer::ToLiteral; +%unignore xla::swig::CompiledLocalComputation; +%unignore xla::swig::CompiledLocalComputation::Execute; +%unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; +%unignore xla::swig::LocalComputation; +%unignore xla::swig::LocalComputation::Compile; +%unignore xla::swig::LocalComputationBuilder; +%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; +%unignore xla::swig::LocalComputationBuilder::Build; +%unignore xla::swig::LocalComputationBuilder::SetOpMetadata; +%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; +%unignore xla::swig::LocalComputationBuilder::Parameter; +%unignore xla::swig::LocalComputationBuilder::GetShape; +%unignore xla::swig::LocalComputationBuilder::Infeed; +%unignore xla::swig::LocalComputationBuilder::Outfeed; +%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; +%unignore xla::swig::LocalComputationBuilder::ConstantR0; +%unignore xla::swig::LocalComputationBuilder::Broadcast; +%unignore xla::swig::LocalComputationBuilder::Pad; +%unignore xla::swig::LocalComputationBuilder::Reshape; +%unignore xla::swig::LocalComputationBuilder::Collapse; +%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; +%unignore xla::swig::LocalComputationBuilder::Slice; +%unignore xla::swig::LocalComputationBuilder::DynamicSlice; +%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; +%unignore xla::swig::LocalComputationBuilder::ConcatInDim; +%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding; +%unignore xla::swig::LocalComputationBuilder::Select; +%unignore xla::swig::LocalComputationBuilder::Tuple; +%unignore xla::swig::LocalComputationBuilder::GetTupleElement; +%unignore xla::swig::LocalComputationBuilder::ConvertElementType; +%unignore xla::swig::LocalComputationBuilder::Call; +%unignore xla::swig::LocalComputationBuilder::Transpose; +%unignore xla::swig::LocalComputationBuilder::Rev; +%unignore xla::swig::LocalComputationBuilder::Map; +%unignore xla::swig::LocalComputationBuilder::Reduce; +%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; +%unignore xla::swig::LocalComputationBuilder::RngNormal; +%unignore xla::swig::LocalComputationBuilder::RngUniform; +%unignore xla::swig::LocalComputationBuilder::RngBernoulli; +%unignore xla::swig::LocalComputationBuilder::While; +%unignore xla::swig::LocalComputationBuilder::Eq; +%unignore xla::swig::LocalComputationBuilder::Ne; +%unignore xla::swig::LocalComputationBuilder::Ge; +%unignore xla::swig::LocalComputationBuilder::Gt; +%unignore xla::swig::LocalComputationBuilder::Lt; +%unignore xla::swig::LocalComputationBuilder::Le; +%unignore xla::swig::LocalComputationBuilder::Dot; +%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; +%unignore xla::swig::LocalComputationBuilder::Add; +%unignore xla::swig::LocalComputationBuilder::Sub; +%unignore xla::swig::LocalComputationBuilder::Mul; +%unignore xla::swig::LocalComputationBuilder::Div; +%unignore xla::swig::LocalComputationBuilder::Rem; +%unignore xla::swig::LocalComputationBuilder::Max; +%unignore xla::swig::LocalComputationBuilder::Min; +%unignore xla::swig::LocalComputationBuilder::And; +%unignore xla::swig::LocalComputationBuilder::Or; +%unignore xla::swig::LocalComputationBuilder::Not; +%unignore xla::swig::LocalComputationBuilder::Abs; +%unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Floor; +%unignore xla::swig::LocalComputationBuilder::Ceil; +%unignore xla::swig::LocalComputationBuilder::Log; +%unignore xla::swig::LocalComputationBuilder::Sign; +%unignore xla::swig::LocalComputationBuilder::Cos; +%unignore xla::swig::LocalComputationBuilder::Sin; +%unignore xla::swig::LocalComputationBuilder::Tanh; +%unignore xla::swig::LocalComputationBuilder::SqrtF32; +%unignore xla::swig::LocalComputationBuilder::SquareF32; +%unignore xla::swig::LocalComputationBuilder::Pow; +%unignore xla::swig::LocalComputationBuilder::IsFinite; +%unignore xla::swig::LocalComputationBuilder::ReciprocalF32; +%unignore xla::swig::LocalComputationBuilder::Neg; +%unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::DeleteLocalShapedBuffer; +%unignore xla::swig::DeleteLocalComputation; +%unignore xla::swig::DeleteCompiledLocalComputation; + +%thread; +%include "tensorflow/compiler/xla/python/local_computation_builder.h" +%nothread; + +%unignoreall diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc new file mode 100644 index 0000000000000000000000000000000000000000..5c722623e318ece9eca6bdc8750195ce5fd5defb --- /dev/null +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -0,0 +1,495 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace swig { + +namespace numpy { + +int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { + switch (primitive_type) { + case PRED: + return NPY_BOOL; + case S8: + return NPY_INT8; + case S16: + return NPY_INT16; + case S32: + return NPY_INT32; + case S64: + return NPY_INT64; + case U8: + return NPY_UINT8; + case U16: + return NPY_UINT16; + case U32: + return NPY_UINT32; + case U64: + return NPY_UINT64; + case F16: + return NPY_FLOAT16; + case F32: + return NPY_FLOAT32; + case F64: + return NPY_FLOAT64; + case TUPLE: + return NPY_OBJECT; + default: + LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type; + } +} + +PrimitiveType NumpyTypeToPrimitiveType(int np_type) { + switch (np_type) { + case NPY_BOOL: + return PRED; + case NPY_INT8: + return S8; + case NPY_INT16: + return S16; + case NPY_INT32: + return S32; + case NPY_INT64: + return S64; + case NPY_UINT8: + return U8; + case NPY_UINT16: + return U16; + case NPY_UINT32: + return U32; + case NPY_UINT64: + return U64; + case NPY_FLOAT16: + return F16; + case NPY_FLOAT32: + return F32; + case NPY_FLOAT64: + return F64; + case NPY_OBJECT: + return TUPLE; + default: + LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type; + } +} + +bool NumpyTypeIsValid(int np_type) { + switch (np_type) { + case NPY_BOOL: + case NPY_INT8: + case NPY_INT16: + case NPY_INT32: + case NPY_INT64: + case NPY_UINT8: + case NPY_UINT16: + case NPY_UINT32: + case NPY_UINT64: + case NPY_FLOAT16: + case NPY_FLOAT32: + case NPY_FLOAT64: + case NPY_OBJECT: + return true; + default: + return false; + } +} + +PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { + int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); + PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); + + PyObject* dimensions; + if (ShapeUtil::IsTuple(shape)) { + int num_elements = ShapeUtil::TupleElementCount(shape); + dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); + for (int i = 0; i < num_elements; ++i) { + PyTuple_SET_ITEM( + dimensions, i, + PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); + } + } else { + int rank = ShapeUtil::Rank(shape); + dimensions = PyTuple_New(rank); + for (int i = 0; i < rank; ++i) { + PyTuple_SET_ITEM(dimensions, i, + LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); + } + } + return PyTuple_Pack(2, np_dtype, dimensions); +} + +// Precondition: o->ob_type == &PyArrayDescr_Type +static int NumpyTypenum(PyObject* o) { + return reinterpret_cast(o)->type_num; +} + +// Extracts the string held inside r and returns it as a C++ string. +// +// NOTE: this is an internal helper for conversion to a C++, and so decrefs r. +static string ExtractStringAndDecref(PyObject* r) { + auto error = [r] { + return tensorflow::strings::Printf("", r); + }; + if (r == nullptr) { + return error(); + } +#if PY_MAJOR_VERSION < 3 + string result = PyString_AsString(r); +#else + PyObject* bytes = PyUnicode_AsEncodedString(r, 0, 0); + if (bytes == nullptr) { + return error(); + } + CHECK(PyBytes_Check(bytes)); + string result = PyBytes_AsString(bytes); + Py_DECREF(bytes); +#endif + Py_DECREF(r); + return result; +} + +// Safely returns a str of the given Python object o as a C++ string. +static string PyObjectCppStr(PyObject* o) { + PyObject* s = PyObject_Str(o); + return ExtractStringAndDecref(s); +} + +// Safely returns a repr of the given Python object o as a C++ string. +static string PyObjectCppRepr(PyObject* o) { + PyObject* r = PyObject_Repr(o); + return ExtractStringAndDecref(r); +} + +Status CheckPyShapeInfo(PyObject* o) { + auto error = [o](const string& prefix) { + return InvalidArgument("%s; got %s", prefix.c_str(), + PyObjectCppRepr(o).c_str()); + }; + // The object is a tuple (a pair) + if (!PyTuple_Check(o)) { + return error("Shape record must be a tuple"); + } + if (PyTuple_Size(o) != 2) { + return error("Shape record tuple must be of length 2"); + } + + // It has a first element, which is a numpy dtype object + PyObject* first = PyTuple_GetItem(o, 0); + if (first == nullptr) { + return error("Tuple has no item 0 (shape dtype)"); + } + if (first->ob_type != &PyArrayDescr_Type) { + return error( + "Shape record does not have a numpy dtype as its first element"); + } + const int np_type = NumpyTypenum(first); + if (!NumpyTypeIsValid(np_type)) { + return error("Shape record has an invalid integer dtype"); + } + + // It has a second element, which is a tuple, either of shape + // records or of Python ints + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + return error("Tuple has no item 0 (shape dimensions)"); + } + if (!PyTuple_Check(second)) { + return error("Shape record does not have a tuple as its second element"); + } + const int length = PyTuple_Size(second); + const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); + for (int i = 0; i < length; i++) { + PyObject* dimension = PyTuple_GetItem(second, i); + if (element_type == TUPLE) { + VLOG(3) << "element_type is tuple, checking member: " << i; + Status result = CheckPyShapeInfo(dimension); + if (!result.ok()) { + return AddStatus( + result, tensorflow::strings::StrCat("Validating tuple member ", i, + " of ", PyObjectCppRepr(o))); + } + } else if (!CheckPyIntOrLong(dimension)) { + return error("Non-tuple shape record has a non-integer dimension"); + } + } + + return Status::OK(); +} + +// Precondition: CheckPyShapeInfo(o) +Shape XlaShapeFromPyShapeInfo(PyObject* o) { + const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0)); + const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type); + PyObject* py_dimensions = PyTuple_GetItem(o, 1); + const int length = PyTuple_Size(py_dimensions); + if (element_type == TUPLE) { + std::vector subshapes; + subshapes.reserve(length); + for (int i = 0; i < length; i++) { + subshapes.push_back( + XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i))); + } + return ShapeUtil::MakeTupleShape(subshapes); + } else { + std::vector dimensions(length); + for (int i = 0; i < length; i++) { + dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i)); + if (dimensions[i] == -1) { + CHECK(!PyErr_Occurred()); + } + } + return ShapeUtil::MakeShape(element_type, dimensions); + } +} + +// Helper that retrieves the member with attr_name, stringifies it if is not +// None, and returns it as a C++ string. +static tensorflow::gtl::optional GetAttrAsString( + PyObject* o, const string& attr_name) { + if (!PyObject_HasAttrString(o, attr_name.c_str())) { + return tensorflow::gtl::nullopt; + } + PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); + if (attr == Py_None) { + Py_DECREF(attr); + return tensorflow::gtl::nullopt; + } + string result = PyObjectCppStr(attr); + Py_DECREF(attr); + return result; +} + +// Helper that retrieves the member with attr_name, checks that it is an integer +// if it is not None, and returns it as an int32 value. +static tensorflow::gtl::optional GetAttrAsInt32( + PyObject* o, const string& attr_name) { + if (!PyObject_HasAttrString(o, attr_name.c_str())) { + return tensorflow::gtl::nullopt; + } + PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); + if (attr == Py_None) { + Py_DECREF(attr); + return tensorflow::gtl::nullopt; + } + if (!CheckPyIntOrLong(attr)) { + Py_DECREF(attr); + return tensorflow::gtl::nullopt; + } + long value = PyIntOrPyLongToLong(attr); // NOLINT + Py_DECREF(attr); + if (value == -1 && PyErr_Occurred() != nullptr) { + return tensorflow::gtl::nullopt; + } + if (static_cast(value) != value) { + return tensorflow::gtl::nullopt; + } + return value; +} + +StatusOr OpMetadataFromPyObject(PyObject* o) { + OpMetadata result; + tensorflow::gtl::optional op_type = GetAttrAsString(o, "op_type"); + if (op_type.has_value()) { + result.set_op_type(op_type.value()); + } + tensorflow::gtl::optional op_name = GetAttrAsString(o, "op_name"); + if (op_name.has_value()) { + result.set_op_name(op_name.value()); + } + tensorflow::gtl::optional source_file = + GetAttrAsString(o, "source_file"); + if (source_file.has_value()) { + result.set_source_file(source_file.value()); + } + tensorflow::gtl::optional source_line = + GetAttrAsInt32(o, "source_line"); + if (source_line.has_value()) { + result.set_source_line(source_line.value()); + } + return result; +} + +PyObject* PyObjectFromXlaLiteral(const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + int num_elements = ShapeUtil::TupleElementCount(literal.shape()); + PyObject* tuple = PyTuple_New(num_elements); + for (int i = 0; i < num_elements; i++) { + PyTuple_SET_ITEM( + tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i}))); + } + return tuple; + } else { + int rank = ShapeUtil::Rank(literal.shape()); + std::vector dimensions(rank); // NOLINT - PyArray requires a long* + for (int i = 0; i < rank; i++) { + dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); + } + int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); + PyObject* array = + PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0); + CopyLiteralToNumpyArray(np_type, literal, + reinterpret_cast(array)); + return array; + } +} + +StatusOr> XlaLiteralFromPyObject(PyObject* o) { + if (PyTuple_Check(o)) { + int num_elements = PyTuple_Size(o); + std::vector> elements; + elements.reserve(num_elements); + for (int i = 0; i < num_elements; i++) { + PyObject* element = PyTuple_GetItem(o, i); + TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element)); + elements.push_back(std::move(literal)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } else if (PyArray_Check(o)) { + PyArrayObject* py_array = reinterpret_cast(o); + int rank = PyArray_NDIM(py_array); + std::vector dimensions(rank); + for (int i = 0; i < rank; i++) { + dimensions[i] = PyArray_DIM(py_array, i); + } + int np_type = PyArray_TYPE(py_array); + auto literal = Literal::CreateFromDimensions( + NumpyTypeToPrimitiveType(np_type), dimensions); + TF_RETURN_IF_ERROR( + CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); + return std::move(literal); + } else { + return InvalidArgument( + "Non-tuple or Numpy array encountered in conversion to XLA literal."); + } +} + +Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, + Literal* literal) { + switch (np_type) { + case NPY_BOOL: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_INT32: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_INT64: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_UINT8: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_UINT32: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_UINT64: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_FLOAT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_FLOAT32: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_FLOAT64: + CopyNumpyArrayToLiteral(py_array, literal); + break; + default: + return InvalidArgument( + "No XLA literal container for Numpy type number: %d", np_type); + } + return Status::OK(); +} + +void CopyLiteralToNumpyArray(int np_type, const Literal& literal, + PyArrayObject* py_array) { + switch (np_type) { + case NPY_BOOL: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT32: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT64: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_UINT8: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_UINT32: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_UINT64: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_FLOAT16: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_FLOAT32: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_FLOAT64: + CopyLiteralToNumpyArray(literal, py_array); + break; + default: + LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; + } +} + +PyObject* LongToPyIntOrPyLong(long x) { // NOLINT +#if PY_MAJOR_VERSION < 3 + return PyInt_FromLong(x); +#else + return PyLong_FromLong(x); +#endif +} + +long PyIntOrPyLongToLong(PyObject* o) { // NOLINT +#if PY_MAJOR_VERSION < 3 + return PyInt_AsLong(o); +#else + return PyLong_AsLong(o); +#endif +} + +bool CheckPyIntOrLong(PyObject* o) { +#if PY_MAJOR_VERSION < 3 + return PyInt_Check(o); +#else + if (!PyLong_Check(o)) { + return false; + } + int overflow = 0; + PyLong_AsLongAndOverflow(o, &overflow); + return (overflow == 0); +#endif +} + +PyObject* PyNumberToPyInt(PyObject* o) { +#if PY_MAJOR_VERSION < 3 + return PyNumber_Int(o); +#else + return PyNumber_Long(o); +#endif +} + +} // namespace numpy + +} // namespace swig + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h new file mode 100644 index 0000000000000000000000000000000000000000..6ff1c34cfc5e0323a6729bdfd5572239f4966211 --- /dev/null +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// These functions transform Python/Numpy data structures to XLA data +// structures and vice versa, performing copies where +// appropriate. Python tuples and Numpy ndarrays translate to XLA +// tuples and XLA literals, respectively, and Numpy shape/dtype +// information is translated to XLA shape information. + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/python/lib/core/numpy.h" + +namespace xla { + +namespace swig { + +namespace numpy { + +// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy +// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and +// vice versa. +int PrimitiveTypeToNumpyType(PrimitiveType primitive_type); +PrimitiveType NumpyTypeToPrimitiveType(int np_type); + +// Determines whether an integer-encoded Numpy dtype is valid, +// i.e. has a supported conversion to an XLA PrimitiveType. +bool NumpyTypeIsValid(int np_type); + +// Converts XLA shape information into a Python pair of the form +// (numpy dtype, dimensions). If the XLA shape represents a tuple, +// then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a +// Python tuple of shape-description pairs, created +// recursively. Otherwise, `dimensions` is a Python tuple-of-integers +// providing the array dimensions. +// +// The return value is a new reference. +PyObject* PyShapeInfoFromXlaShape(const Shape& shape); + +// Returns the outcome of a best-effort check that the Python object +// is a pair of the form (numpy dtype, dimensions), as produced by +// PyShapeInfoFromXlaShape. +Status CheckPyShapeInfo(PyObject* o); + +// Performs the inverse conversion to that of PyShapeInfoFromXlaShape. +// +// The return value is a new reference. +Shape XlaShapeFromPyShapeInfo(PyObject* o); + +// Converts a PyObject that represents operation metadata into protocol buffer +// form. +StatusOr OpMetadataFromPyObject(PyObject* o); + +// Converts an XLA literal to a Python object, either a Numpy ndarray +// or a nested Python tuple thereof. +// +// To avoid transferring ownership of the data buffers that underlie +// PyArrays and XLA literals, this function makes deep copies of all +// array data. +// +// The return value is a new reference. +PyObject* PyObjectFromXlaLiteral(const Literal& literal); + +// Converts a Numpy ndarray or a nested Python tuple thereof to a +// corresponding XLA literal. +// +// To avoid transferring ownership of the data buffers that underlie +// PyArrays and XLA literals, this function makes deep copies of all +// array data. +StatusOr > XlaLiteralFromPyObject(PyObject* o); + +// The following functions copy array data from the buffers underlying Numpy +// ndarrays into those underlying XLA literals, and vice versa. + +Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, + Literal* literal); + +void CopyLiteralToNumpyArray(int np_type, const Literal& literal, + PyArrayObject* py_array); + +template +void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { + NativeT* source = static_cast(PyArray_DATA(py_array)); + auto dest = literal->data(); + std::copy(source, source + PyArray_SIZE(py_array), dest.data()); +} + +template +void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { + NativeT* dest = static_cast(PyArray_DATA(py_array)); + auto source = literal.data(); + std::copy(source.begin(), source.end(), dest); +} + +// Workarounds for Python 2 and 3 interop + +PyObject* LongToPyIntOrPyLong(long x); // NOLINT +long PyIntOrPyLongToLong(PyObject* o); // NOLINT +bool CheckPyIntOrLong(PyObject* o); +PyObject* PyNumberToPyInt(PyObject* o); + +} // namespace numpy + +} // namespace swig + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ diff --git a/tensorflow/compiler/xla/python/xla.i b/tensorflow/compiler/xla/python/xla.i new file mode 100644 index 0000000000000000000000000000000000000000..1c4021a558d3fcff2abfdbdbad7f3928e86ed3b8 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla.i @@ -0,0 +1,18 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* XLA-wide SWIG wrapper */ + +%include "tensorflow/compiler/xla/python/local_computation_builder.i" diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py new file mode 100644 index 0000000000000000000000000000000000000000..5455adafcded90dbe38b4c444d2bc03fae445888 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -0,0 +1,999 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An in-process, local XLA client in Python, supporting AOT compilation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import enum # pylint: disable=g-bad-import-order +import inspect +import itertools +import os + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python import pywrap_xla as c_api + + +# Most functions are snake_case for consistency with other modules, +# whereas method names of ComputationBuilder and LocalComputation are +# CamelCase for consistency with XLA. +# pylint: disable=invalid-name + + +OpMetadata = collections.namedtuple( + 'OpMetadata', + [ + 'op_type', + 'op_name', + 'source_file', + 'source_line', + ], +) + + +def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): + """Helper for use in source mapping that returns an OpMetadata object.""" + full_filename, lineno = inspect.stack()[skip_frames][1:3] + filename = os.path.basename(full_filename) + return OpMetadata( + op_type=op_type, + op_name=op_name, + source_file=filename, + source_line=lineno) + + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + + +def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, + window_strides): + """Maps PaddingType (VALID or SAME) to pad values (list of pairs of ints).""" + if padding_type == PaddingType.VALID: + return [(0, 0)] * len(window_strides) + + out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) + pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0) + for out_size, stride, filter_size, in_size + in zip(out_shape, window_strides, rhs_dims, lhs_dims)] + return [(pad_size // 2, pad_size - pad_size // 2) + for pad_size in pad_sizes] + + +_UNARY_OPS = [ + 'Not', + 'Abs', + 'Exp', + 'Floor', + 'Ceil', + 'Log', + 'Sign', + 'Cos', + 'Sin', + 'Tanh', + 'SqrtF32', + 'SquareF32', + 'IsFinite', + 'ReciprocalF32', + 'Neg', + 'Sort', +] + +_BINARY_OPS = [ + 'Eq', + 'Ne', + 'Ge', + 'Gt', + 'Lt', + 'Le', + 'Add', + 'Sub', + 'Mul', + 'Div', + 'Rem', + 'Max', + 'Min', + 'And', + 'Or', + 'Pow', +] + +XLA_ELEMENT_TYPE_TO_DTYPE = { + xla_data_pb2.F32: np.dtype(np.float32), + xla_data_pb2.F64: np.dtype(np.float64), + xla_data_pb2.S32: np.dtype(np.int32), + xla_data_pb2.S64: np.dtype(np.int64), + xla_data_pb2.U32: np.dtype(np.uint32), + xla_data_pb2.U64: np.dtype(np.uint64), + xla_data_pb2.PRED: np.dtype(np.bool), + xla_data_pb2.TUPLE: np.dtype(np.object), +} + +# Note the conversion on the key. Numpy has a known issue wherein dtype hashing +# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, +# when keying by dtype in this dict, we use the string form of dtypes. +DTYPE_TO_XLA_ELEMENT_TYPE = { + str(v): k + for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() +} + + +class LocalBuffer(object): + """Represents a handle to data owned by XLA. + + The referent is ready for use in executing a local, compiled + Computation. On XLA platforms involving a device (e.g. GPU), this + means the referent is in device memory. + """ + + def __init__(self, c_local_shaped_buffer): + self.c_local_shaped_buffer = c_local_shaped_buffer + self._delete = c_api.DeleteLocalShapedBuffer + + @staticmethod + def from_py(npval): + npval = require_numpy_array_layout(npval) + return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval)) + + def to_py(self): + return self.c_local_shaped_buffer.ToLiteral() + + def delete(self): + if self.c_local_shaped_buffer is not None: + self._delete(self.c_local_shaped_buffer) + self.c_local_shaped_buffer = None + + def is_deleted(self): + return self.c_local_shaped_buffer is None + + def __del__(self): + self.delete() + + +class Shape(object): + """XLA shape. + + Represents an XLA shape by a corresponding Python/Numpy type and a + list of dimensions, which are themselves Shapes in case this one + represents an XLA tuple. + """ + + def __init__(self, np_dtype, dimensions): + self.np_dtype = np_dtype + self._dimensions = dimensions + + def __repr__(self): + return 'xla_client.Shape(np_dtype={!r}, dimensions={!r})'.format( + self.np_dtype, self._dimensions) + + def element_type(self): + return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] + + def is_tuple(self): + return self.element_type() == xla_data_pb2.TUPLE + + def dimensions(self): + if self.is_tuple(): + raise ValueError('Tuple shape has no dimensions') + return self._dimensions + + def tuple_shapes(self): + if not self.is_tuple(): + raise ValueError('Shape is not a tuple shape') + return self._dimensions + + @staticmethod + def from_numpy(npval): + + def convert(npval): + if isinstance(npval, tuple): + return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval)) + else: + return Shape(npval.dtype, np.shape(npval)) + + return convert(require_numpy_array_layout(npval)) + + +def _wrap_shape(shape_info): + dtype, dims = shape_info + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] + if element_type == xla_data_pb2.TUPLE: + dims = [_wrap_shape(subshape_info) for subshape_info in dims] + return Shape(dtype, dims) + + +def _unwrap_shape(shape): + if shape.is_tuple(): + components = tuple( + _unwrap_shape(subshape) for subshape in shape.tuple_shapes()) + else: + components = shape.dimensions() + return (shape.np_dtype, components) + + +def _unwrap_shapes(shapes): + return [_unwrap_shape(shape) for shape in shapes] + + +def _wrap_data_handle(handle): + cdh = xla_data_pb2.ComputationDataHandle() + cdh.handle = handle + return cdh + + +def _unwrap_data_handle(handle_proto): + return handle_proto.handle + + +def _unwrap_data_handles(handle_protos): + return [_unwrap_data_handle(cdh) for cdh in handle_protos] + + +def require_numpy_array_layout(value): + if isinstance(value, tuple): + return tuple(require_numpy_array_layout(x) for x in value) + else: + return np.require(value, requirements=['C', 'A']) + + +def transfer_to_infeed(value, replica_number=None): + """Transfers the given value into the XLA infeed queue. + + XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with + a totally ordered stream of values. This is dequeued from XLA computations via + the Infeed() operation. + + Args: + value: the value that the caller would like to enqueue into the XLA infeed + queue + replica_number: the replica number to infeed the value to -- if not + provided, then the default replica (trivially replica 0) is used. + """ + if replica_number is None: + c_api.TransferToInfeedLocal(require_numpy_array_layout(value)) + else: + c_api.TransferToInfeedLocalReplica( + require_numpy_array_layout(value), replica_number) + + +def transfer_from_outfeed(shape, replica_number=None): + """Transfers a literal of the given shape from replica_number's outfeed. + + Args: + shape: The shape of the value to transfer from outfeed. + replica_number: The replica number ordinal to transfer the outfeed value + from. (Each replica has a distinct outfeed queue.) + + Returns: + The literal value that is produced from the outfeed queue. + """ + return c_api.TransferFromOutfeedLocalReplica( + _unwrap_shape(shape), replica_number or 0) + + +class LocalComputation(object): + """Python wrapper for a local XLA Computation. + + A LocalComputation can be executed if it is compiled. Otherwise, it + can still be used as a Computation where required by the + ComputationBuilder methods. + """ + + def __init__(self, c_local_computation, is_compiled): + self.c_local_computation = c_local_computation + self.is_compiled = is_compiled + + # Ensure a reference to C-based destructor for use in __del__. + if is_compiled: + self._delete = c_api.DeleteCompiledLocalComputation + else: + self._delete = c_api.DeleteLocalComputation + + def Compile(self, argument_shapes=()): + if self.is_compiled: + raise ValueError('Attempt to compile a compiled local XLA computation.') + return LocalComputation( + self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)), + is_compiled=True) + + def CompileWithExampleArguments(self, arguments=()): + return self.Compile( + argument_shapes=[Shape.from_numpy(arg) for arg in arguments]) + + def Execute(self, arguments=()): + if not self.is_compiled: + raise ValueError('Cannot execute an uncompiled local XLA computation.') + arguments = tuple(map(require_numpy_array_layout, arguments)) + return self.c_local_computation.Execute(arguments) + + def ExecuteWithLocalBuffers(self, arguments=()): + """Execute with LocalBuffer arguments and return value.""" + if not self.is_compiled: + raise ValueError('Cannot execute an uncompiled local XLA computation.') + arguments = tuple(arguments) + if any(arg.is_deleted() for arg in arguments): + raise ValueError('Executing with deleted local buffer argument') + return LocalBuffer( + self.c_local_computation.ExecuteWithShapedBuffers( + [arg.c_local_shaped_buffer for arg in arguments])) + + def __del__(self): + self._delete(self.c_local_computation) + + +class ComputationBuilder(object): + """XLA computation builder. + + Enqueues XLA ops in sequence and in order to build a + LocalComputation, which in turn can be compiled into a + CompiledLocalComputation, which in turn can be locally executed. + """ + + # The methods of this class map 1-to-1 onto the XLA C++ + # computation builder API. Therefore, there's no need to laboriously list + # arguments and return values for every method, especially where it's obvious. + # + # pylint: disable=g-doc-return-or-yield + # pylint: disable=g-doc-args + + def __init__(self, name): + self._client = c_api.LocalComputationBuilder(name.encode('utf8')) + self._parameter_numbering = itertools.count() + + def Build(self): + return LocalComputation(self._client.Build(), is_compiled=False) + + def SetOpMetadata(self, op_metadata): + """Set metadata for operations that are about to be enqueued.""" + self._client.SetOpMetadata(op_metadata) + + def ClearOpMetadata(self): + """Clear metadata for operations that are about to be enqueued.""" + self._client.ClearOpMetadata() + + def Infeed(self, shape): + """Enqueues an infeed op onto the computation. + + Infeed operations dequeue data of the given shape from the device's infeed + queue for subsequent use in the computation. + + Returns: + A ComputationDataHandle message. + """ + return _wrap_data_handle(self._client.Infeed(_unwrap_shape(shape))) + + def Outfeed(self, operand): + """Enqueues an outfeed op onto the computation. + + Outfeed operations enqueue data, using the given operand, onto the XLA + outfeed queue for subsequent dequeue via the client API. + """ + self._client.Outfeed( + _unwrap_data_handle(operand), _unwrap_shape(self.GetShape(operand)), + ''.encode('utf-8')) + + def Constant(self, value): + """Enqueues a constant op onto the computation. + + Args: + value: value for the constant, as a np.array with an explicit dtype set + to one of the supported types. + + Returns: + A ComputationDataHandle message. + """ + value = require_numpy_array_layout(value) + return _wrap_data_handle(self._client.ConstantLiteral(value)) + + def ConstantF32Scalar(self, value): + """Convenience method to enqueue a scalar F32 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.float32)) + + def ConstantF64Scalar(self, value): + """Convenience method to enqueue a scalar F32 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.float64)) + + def ConstantS32Scalar(self, value): + """Convenience method to enqueue a scalar S32 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.int32)) + + def ConstantS64Scalar(self, value): + """Convenience method to enqueue a scalar S64 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.int64)) + + def ConstantPredScalar(self, value): + """Convenience method to enqueue a scalar PRED constant op. + + Args: + value: a boolean value. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.bool)) + + def ParameterWithShape(self, shape, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation, given a shape. + + Args: + shape: the parameter's shape as a Shape object. + name: optional string name for the parameter. + parameter_num: parameter number in the computation function. If None, + the next linear parameter number is used. The default value capability + can be used for auto-numbering. If you're using auto-numbering for some + parameters, use it for *all* parameters to avoid clashes. + + Returns: + A ComputationDataHandle message. + """ + if name is None: + name = '' + if parameter_num is None: + parameter_num = next(self._parameter_numbering) + + return _wrap_data_handle( + self._client.Parameter( + parameter_num, _unwrap_shape(shape), name.encode('utf8'))) + + def ParameterFromNumpy(self, value, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation. + + Args: + value: a Numpy array, or a nested tuple thereof, from which the + shape is inferred. + name: as in ParameterWithShape. + parameter_num: as in ParameterWithShape. + + Returns: + A ComputationDataHandle message. + """ + return self.ParameterWithShape( + Shape.from_numpy(value), name=name, parameter_num=parameter_num) + + def Broadcast(self, operand, sizes): + """Enqueues a broadcast operation onto the computation. + + Args: + operand: the operand ComputationDataHandle to broadcast. + sizes: an iterable of broadcast sizes. + + Returns: + A ComputationDataHandle representing the added broadcast op. + """ + return _wrap_data_handle( + self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + + def Concatenate(self, operands, dimension): + """Enqueues a concatenate operation onto the computation. + + Args: + operands: the operands to concatenate. + dimension: the dimension in which to perform the concatenation. + + Returns: + A ComputationDataHandle representing the added concatenate op. + """ + return _wrap_data_handle( + self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + + def ConvertElementType(self, operand, new_element_type): + """Enqueues an element type conversion operation onto the computation. + + Args: + operand: the operand to convert. + new_element_type: the target primitive type. + + Returns: + A ComputationDataHandle representing the added conversion op. + """ + return _wrap_data_handle( + self._client.ConvertElementType( + _unwrap_data_handle(operand), new_element_type)) + + def GetShape(self, operand): + return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + + def GetComputationStats(self): + raise NotImplementedError() + + def Pad(self, operand, padding_value, padding_config): + """Enqueues a Pad operation onto the computation. + + Args: + operand: ComputationDataHandle representing the array to pad. + padding_value: ComputationDataHandle representing the scalar pad value. + padding_config: either an xla_data_pb2.PaddingConfig or a list of integer + triples (edge_padding_low, edge_padding_high, interior_padding) + representing the configuration of the padding operation. + + Returns: + A ComputationDataHandle representing the added pad op. + """ + if not isinstance(padding_config, xla_data_pb2.PaddingConfig): + padding_config = self._GetPaddingConfigFromTriples(padding_config) + return _wrap_data_handle( + self._client.Pad(_unwrap_data_handle(operand), + _unwrap_data_handle(padding_value), + padding_config)) + + def _GetPaddingConfigFromTriples(self, triples): + """Create PaddingConfig proto from list of triples of integers.""" + padding_config = xla_data_pb2.PaddingConfig() + for lo, hi, interior in triples: + dimension = padding_config.dimensions.add() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + return padding_config + + def Reshape(self, operand, dimensions, new_sizes): + """Reshape op.""" + return _wrap_data_handle( + self._client.Reshape( + _unwrap_data_handle(operand), dimensions, new_sizes)) + + def CrossReplicaSum(self, operand): + """CrossReplicaSum op. + + Args: + operand: the operand to sum across replica instances. + + Returns: + A ComputationDataHandle that has the sum of the value among all replicas. + """ + return _wrap_data_handle( + self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + + def Collapse(self, operand, dimensions): + """Collapse op.""" + return _wrap_data_handle( + self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + + def Trans(self, operand): + """Specialized matrix transpose op.""" + return _wrap_data_handle( + self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + + def Transpose(self, operand, permutation): + """Transpose op.""" + return _wrap_data_handle( + self._client.Transpose(_unwrap_data_handle(operand), permutation)) + + def Rev(self, operand, dimensions): + """Rev op.""" + return _wrap_data_handle( + self._client.Rev(_unwrap_data_handle(operand), dimensions)) + + def SelectAndScatter(self, operand, select, window_dimensions, window_strides, + padding, source, init_value, scatter): + """Select and scatter op, used by the gradient of ReduceWindow. + + Args: + operand: ComputationDataHandle for array of dimension N and type T over + which the windows slide. + select: Computation of type (T, T) -> Pred to apply to the elements of + each window to indicate which element is selected. + window_dimensions: sequence of N integers for dimensions of the window. + window_strides: sequence of N integers for the strides of the window. + padding: PaddingType representing either 'SAME' or 'VALID ' padding. + source: ComputationDataHandle for array of type T with values to scatter. + init_value: ComputationDataHandle of scalar type T for initial out value. + scatter: Computation of type (T, T) -> T to apply to each scatter source + element with its destination element. + + Returns: + A ComputationDataHandle representing the added SelectAndScatter op. + """ + pads = _convert_padding_type_to_pad_values( + padding, self.GetShape(operand).dimensions(), + window_dimensions, window_strides) + return _wrap_data_handle( + self._client.SelectAndScatterWithGeneralPadding( + _unwrap_data_handle(operand), select.c_local_computation, + window_dimensions, window_strides, pads, + _unwrap_data_handle(source), _unwrap_data_handle(init_value), + scatter.c_local_computation)) + + def Select(self, pred, on_true, on_false): + """Element-wise selection op. + + Constructs an output array from elements of two input arrays, based on the + values of a predicate array. + """ + return _wrap_data_handle( + self._client.Select( + _unwrap_data_handle(pred), + _unwrap_data_handle(on_true), + _unwrap_data_handle(on_false))) + + def Slice(self, operand, start_indices, limit_indices, strides=None): + """Enqueues a slice operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_indices: iterable of N integers containing the starting indices of + the slice for each dimension. + limit_indices: iterable of N integers containing the ending indices + (exclusive) of the slice for each dimension. + strides: optional iterable of N integers containing the stride sizes for + each dimension. + + Returns: + A ComputationDataHandle representing the added Slice op. + """ + if strides is None: + start_indices = list(start_indices) + strides = [1] * len(start_indices) + return _wrap_data_handle( + self._client.Slice( + _unwrap_data_handle(operand), + start_indices, + limit_indices, + strides)) + + def DynamicSlice(self, operand, start_indices, slice_sizes): + """Enqueues a slice op with dynamic start indices onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_indices: ComputationDataHandle for the 1D array of N integers + containing the starting indices of the slice. + slice_sizes: iterable of N integers containing the slice sizes in each + dimension. + + Returns: + A ComputationDataHandle representing the added DynamicSlice op. + """ + return _wrap_data_handle( + self._client.DynamicSlice( + _unwrap_data_handle(operand), + _unwrap_data_handle(start_indices), + slice_sizes)) + + def DynamicUpdateSlice(self, operand, update, start_indices): + """Enqueues a dynamic update slice operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be updated. + update: N dimensional array comprising the slice update. + start_indices: Rank-1 array of N integers comprising the starting indices + of the slice along each dimension. + Returns: + A ComputationDataHandle representing the added DynamicUpdateSlice op. + """ + return _wrap_data_handle( + self._client.DynamicUpdateSlice( + _unwrap_data_handle(operand), + _unwrap_data_handle(update), + _unwrap_data_handle(start_indices))) + + def Tuple(self, *ops): + """Enqueues a tuple operation onto the computation. + + Args: + ops: a sequence of tuple operands (each a ComputationDataHandle). + + Returns: + A ComputationDataHandle representing the added Tuple op. + """ + return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + + def GetTupleElement(self, tup, index): + """Enqueues a 'get tuple element' operation onto the computation. + + Args: + tup: the tuple operand (a ComputationDataHandle). + index: numeric index to select from the tuple. + + Returns: + A ComputationDataHandle representing the added GetTupleElement op. + """ + return _wrap_data_handle( + self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + + def Call(self, computation_to_apply, operands): + """Enqueues a call operation onto the computation. + + Args: + computation_to_apply: a Computation object. + operands: an iterable of ComputationDataHandle. The number and types of + operands must match the arity of computation_to_apply. + + Returns: + A ComputationDataHandle representing the added call op. + """ + return _wrap_data_handle( + self._client.Call(computation_to_apply.c_local_computation, + _unwrap_data_handles(operands))) + + def Map(self, operands, computation_to_apply, dimensions, static_operands=()): + """Enqueues a map operation onto the computation. + + Args: + operands: an iterable of ComputationDataHandle. + computation_to_apply: a Computation object. + dimensions: dimensions over which to apply map the function. + static_operands: auxiliary arguments passed to the applied computation. + + Returns: + A ComputationDataHandle representing the added Map op. + """ + return _wrap_data_handle( + self._client.Map( + _unwrap_data_handles(operands), + computation_to_apply.c_local_computation, + dimensions, + _unwrap_data_handles(static_operands))) + + def Reduce(self, operand, init_value, computation_to_apply, dimensions): + """Enqueues a reduction operation onto the computation. + + Args: + operand: reduction operand (ComputationDataHandle). + init_value: reduction initial value (ComputationDataHandle). + computation_to_apply: a Computation object - binary reduction function. + dimensions: sequence of dimensions (integers) to reduce on. + + Returns: + A ComputationDataHandle representing the added Reduce op. + """ + return _wrap_data_handle( + self._client.Reduce( + _unwrap_data_handle(operand), + _unwrap_data_handle(init_value), + computation_to_apply.c_local_computation, + dimensions)) + + def ReduceWindow(self, operand, init_value, computation_to_apply, + window_dimensions, window_strides, padding): + """Enqueues a windowed reduction operation onto the computation. + + Args: + operand: reduction operand (ComputationDataHandle). + init_value: reduction initial value (ComputationDataHandle). + computation_to_apply: a binary reduction function (Computation). + window_dimensions: dimensions of window (sequence of integers). + window_strides: strides for window (sequence of integers). + padding: PaddingType representing either 'SAME' or 'VALID' padding. + + Returns: + A ComputationDataHandle representing the added ReduceWindow op. + """ + pads = _convert_padding_type_to_pad_values( + padding, self.GetShape(operand).dimensions(), window_dimensions, + window_strides) + return _wrap_data_handle( + self._client.ReduceWindowWithGeneralPadding( + _unwrap_data_handle(operand), + _unwrap_data_handle(init_value), + computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads)) + + def RngNormal(self, mu, sigma, dims): + """Enqueues an RngNormal operation onto the computation. + + Args: + mu: A ComputationDataHandle to an F32 scalar specifying the mean. + sigma: A ComputationDataHandle to an F32 scalar specifying the standard + deviation. + dims: A 1D array-like of nonnegative integers specifying the dimensions. + + Returns: a ComputationDataHandle to the generated array of F32 values. + """ + shape = Shape(self.GetShape(mu).np_dtype, dims) + return _wrap_data_handle( + self._client.RngNormal( + _unwrap_data_handle(mu), _unwrap_data_handle(sigma), + _unwrap_shape(shape))) + + def RngUniform(self, a, b, dims): + """Enqueues an RngUniform operation onto the computation. + + Args: + a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + the type of b) specifying the low end of the interval [a, b) over which + values are generated. + b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + the type of a) specifying the high end of the interval [a, b) over which + values are generated. + dims: A 1D array-like of nonnegative integers specifying the dimensions. + + Returns: a ComputationDataHandle to the generated array of values with the + same numeric type (F32, S32, or U32) as the arguments a and b. + """ + shape = Shape(self.GetShape(a).np_dtype, dims) + return _wrap_data_handle( + self._client.RngUniform( + _unwrap_data_handle(a), _unwrap_data_handle(b), + _unwrap_shape(shape))) + + def While(self, cond, body, init): + """Enqueues a While operation onto the computation. + + Args: + cond: a Computation for the loop condition, which has type T -> PRED + body: a Computation for the loop body, which has type T -> T + init: an ComputationDataHandle for the initial parameter, which has type T + + Returns: a ComputationDataHandle representing the While operation. + """ + return _wrap_data_handle( + self._client.While(cond.c_local_computation, + body.c_local_computation, + _unwrap_data_handle(init))) + + def Dot(self, lhs, rhs): + """Matrix multiplication between lhs and rhs.""" + return _wrap_data_handle( + self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + + def Conv(self, lhs, rhs, window_strides, padding): + """Enqueues a Conv operation onto the computation. + + Args: + lhs: ComputationDataHandle for the rank N+2 array of inputs. + rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + window_strides: length-N array-like of integer kernel strides. + padding: PaddingType representing either 'SAME' or 'VALID' padding. + + Returns: a ComputationDataHandle representing the Conv operation. + """ + pads = _convert_padding_type_to_pad_values( + padding, self.GetShape(lhs).dimensions()[2:], + self.GetShape(rhs).dimensions()[2:], window_strides) + dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) + return _wrap_data_handle( + self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), + _unwrap_data_handle(rhs), + window_strides, + pads, + (), + (), + dimension_numbers)) + + def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation): + """Enqueues a ConvWithGeneralPadding operation onto the computation. + + Args: + lhs: ComputationDataHandle for the rank N+2 array of inputs. + rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + window_strides: length-N array-like of kernel strides. + padding: length-N array-like of pairs of integers of (low, high) padding. + lhs_dilation: length-N array-like of dilation factors. + rhs_dilation: length-N array-like of dilation factors. + + Returns: + A ComputationdataHandle representing the added ConvWithGeneralPadding op. + """ + dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) + return _wrap_data_handle( + self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), + _unwrap_data_handle(rhs), + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers)) + + def _GetConvDimensionNumbers(self, num_spatial_dims): + """Create ConvolutionDimensionNumbers proto for convolutions.""" + nd = num_spatial_dims + dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + dimension_numbers.input_batch_dimension = 0 + dimension_numbers.input_feature_dimension = 1 + dimension_numbers.output_batch_dimension = 0 + dimension_numbers.output_feature_dimension = 1 + dimension_numbers.kernel_output_feature_dimension = 0 + dimension_numbers.kernel_input_feature_dimension = 1 + dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) + return dimension_numbers + + +def _forward_methods_to_local_builder(): + """Forward remaining ComputationBuilder methods to the C API. + + Set up methods, corresponding to unary and binary XLA operations, + whose calls are forwarded in a boilerplate manner to the underlying + LocalComputationBuilder C-extension API. + """ + + def forward_to_local_builder_with_handles(target_method, is_binop=False): + """Generate a forwarding method that wraps/unwraps data handles.""" + + def forward(self, *args, **kwargs): + unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + + if is_binop and len(unwrapped_args) < 3: + unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + + return _wrap_data_handle( + target_method( + self._client, # pylint: disable=protected-access + *unwrapped_args)) + + return forward + + for method_name in _UNARY_OPS: + forward = forward_to_local_builder_with_handles( + getattr(c_api.LocalComputationBuilder, method_name)) + forward.__name__ = method_name + setattr(ComputationBuilder, method_name, forward) + + for method_name in _BINARY_OPS: + forward = forward_to_local_builder_with_handles( + getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) + forward.__name__ = method_name + setattr(ComputationBuilder, method_name, forward) + + +_forward_methods_to_local_builder() + + +def initialize_replica_count(replica_count): + """Initializes the desired replica count to use on XLA service init. + + Args: + replica_count: number of replicas that are desired for set up during XLA + initalization. + + Raises: + A runtime exception if the XLA service has already been initialized. + """ + c_api.InitializeReplicaCount(replica_count) + + +def get_replica_count(): + """Returns the current replica count used for the XLA service. + + Note: this will return a value whether the XLA service has been initialized + yet or not. + """ + return c_api.GetReplicaCount() diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c0413b9bbc3b7f8b63e4cf7a8f24980322cffc47 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -0,0 +1,1223 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the Python extension-based XLA client.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import threading + +import numpy as np + +from tensorflow.compiler.xla.python import xla_client +import unittest + + +class LocalComputationTest(unittest.TestCase): + """Base class for running an XLA Computation through the local client.""" + + def _NewComputation(self, name=None): + if name is None: + name = self.id() + return xla_client.ComputationBuilder(name) + + def _Execute(self, c, arguments): + compiled_c = c.Build().CompileWithExampleArguments(arguments) + return compiled_c.Execute(arguments) + + def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): + assert expected is not None + result = self._Execute(c, arguments) + # Numpy's comparison methods are a bit too lenient by treating inputs as + # "array-like", meaning that scalar 4 will be happily compared equal to + # [[4]]. We'd like to be more strict so assert shapes as well. + self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape) + assert_func(result, expected) + + def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) + + def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, + expected) + + +def NumpyArrayF32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" + return np.array(*args, dtype=np.float32, **kwargs) + + +def NumpyArrayF64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" + return np.array(*args, dtype=np.float64, **kwargs) + + +def NumpyArrayS32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" + return np.array(*args, dtype=np.int32, **kwargs) + + +def NumpyArrayS64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int64 dtype.""" + return np.array(*args, dtype=np.int64, **kwargs) + + +def NumpyArrayBool(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.bool dtype.""" + return np.array(*args, dtype=np.bool, **kwargs) + + +class ComputationsWithConstantsTest(LocalComputationTest): + """Tests focusing on Constant ops.""" + + def testConstantScalarSumF32(self): + c = self._NewComputation() + c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testConstantScalarSumF64(self): + c = self._NewComputation() + c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testConstantScalarSumS32(self): + c = self._NewComputation() + c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) + self._ExecuteAndCompareClose(c, expected=3) + + def testConstantScalarSumS64(self): + c = self._NewComputation() + c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) + self._ExecuteAndCompareClose(c, expected=3) + + def testConstantVectorMulF32(self): + c = self._NewComputation() + c.Mul( + c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), + c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) + self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + + def testConstantVectorMulF64(self): + c = self._NewComputation() + c.Mul( + c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), + c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) + self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + + def testConstantVectorScalarDivF32(self): + c = self._NewComputation() + c.Div( + c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), + c.ConstantF32Scalar(2.0)) + self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + + def testConstantVectorScalarDivF64(self): + c = self._NewComputation() + c.Div( + c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), + c.ConstantF64Scalar(2.0)) + self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + + def testConstantVectorScalarPowF32(self): + c = self._NewComputation() + c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) + self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + + def testConstantVectorScalarPowF64(self): + c = self._NewComputation() + c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) + self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + + def testBooleanAnd(self): + c = self._NewComputation() + c.And( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, False]) + + def testBooleanOr(self): + c = self._NewComputation() + c.Or( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[True, True, True, False]) + + def testSum2DF32(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + + def testSum2DF64(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) + self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + + def testSum2DWith1DBroadcastDim0F32(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([10, 20, 30])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + + def testSum2DWith1DBroadcastDim0F64(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF64([10, 20, 30])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + + def testSum2DWith1DBroadcastDim1F32(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([10, 20, 30])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + + def testSum2DWith1DBroadcastDim1F64(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF64([10, 20, 30])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + + def testConstantAxpyF32(self): + c = self._NewComputation() + c.Add( + c.Mul( + c.ConstantF32Scalar(2), + c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), + c.Constant(NumpyArrayF32([100, -100, 200, -200]))) + self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + + def testConstantAxpyF64(self): + c = self._NewComputation() + c.Add( + c.Mul( + c.ConstantF64Scalar(2), + c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), + c.Constant(NumpyArrayF64([100, -100, 200, -200]))) + self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + + +class ParametersTest(LocalComputationTest): + """Tests focusing on Parameter ops and argument-passing.""" + + def setUp(self): + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3]) + self.f64_scalar_2 = NumpyArrayF64(2.0) + self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3]) + self.s32_scalar_3 = NumpyArrayS32(3) + self.s32_4vector = NumpyArrayS32([10, 15, -2, 7]) + self.s64_scalar_3 = NumpyArrayS64(3) + self.s64_4vector = NumpyArrayS64([10, 15, -2, 7]) + + def testScalarTimesVectorAutonumberF32(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.f32_scalar_2) + p1 = c.ParameterFromNumpy(self.f32_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareClose( + c, + arguments=[self.f32_scalar_2, self.f32_4vector], + expected=[-4.6, 6.6, -8.6, 10.6]) + + def testScalarTimesVectorAutonumberF64(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.f64_scalar_2) + p1 = c.ParameterFromNumpy(self.f64_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareClose( + c, + arguments=[self.f64_scalar_2, self.f64_4vector], + expected=[-4.6, 6.6, -8.6, 10.6]) + + def testScalarTimesVectorS32(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.s32_scalar_3) + p1 = c.ParameterFromNumpy(self.s32_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, + arguments=[self.s32_scalar_3, self.s32_4vector], + expected=[30, 45, -6, 21]) + + def testScalarTimesVectorS64(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.s64_scalar_3) + p1 = c.ParameterFromNumpy(self.s64_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, + arguments=[self.s64_scalar_3, self.s64_4vector], + expected=[30, 45, -6, 21]) + + def testScalarMinusVectorExplicitNumberingF32(self): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. + c = self._NewComputation() + p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1) + p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0) + c.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, + arguments=[self.f32_scalar_2, self.f32_4vector], + expected=[-4.3, 1.3, -6.3, 3.3]) + + def testScalarMinusVectorExplicitNumberingF64(self): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. + c = self._NewComputation() + p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1) + p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0) + c.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, + arguments=[self.f64_scalar_2, self.f64_4vector], + expected=[-4.3, 1.3, -6.3, 3.3]) + + +class LocalBufferTest(LocalComputationTest): + """Tests focusing on execution with LocalBuffers.""" + + def _Execute(self, c, arguments): + compiled_c = c.Build().CompileWithExampleArguments(arguments) + arg_buffers = [xla_client.LocalBuffer.from_py(arg) for arg in arguments] + result_buffer = compiled_c.ExecuteWithLocalBuffers(arg_buffers) + return result_buffer.to_py() + + def testConstantSum(self): + c = self._NewComputation() + c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testOneParameterSum(self): + c = self._NewComputation() + c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) + self._ExecuteAndCompareClose( + c, + arguments=[NumpyArrayF32(1.11)], + expected=4.25) + + def testTwoParameterSum(self): + c = self._NewComputation() + c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), + c.ParameterFromNumpy(NumpyArrayF32(0.))) + self._ExecuteAndCompareClose( + c, + arguments=[NumpyArrayF32(1.11), NumpyArrayF32(3.14)], + expected=4.25) + + def testCannotCallWithDeletedBuffers(self): + c = self._NewComputation() + c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) + arg = NumpyArrayF32(1.11) + compiled_c = c.Build().CompileWithExampleArguments([arg]) + arg_buffer = xla_client.LocalBuffer.from_py(arg) + arg_buffer.delete() + with self.assertRaises(ValueError): + compiled_c.ExecuteWithLocalBuffers([arg_buffer]) + + +class SingleOpTest(LocalComputationTest): + """Tests for single ops. + + The goal here is smoke testing - to exercise the most basic functionality of + single XLA ops. As minimal as possible number of additional ops are added + around the op being tested. + """ + + def testConcatenateF32(self): + c = self._NewComputation() + c.Concatenate( + (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))), + dimension=0) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def testConcatenateF64(self): + c = self._NewComputation() + c.Concatenate( + (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))), + dimension=0) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def testConvertElementType(self): + xla_types = { + np.bool: xla_client.xla_data_pb2.PRED, + np.int32: xla_client.xla_data_pb2.S32, + np.int64: xla_client.xla_data_pb2.S64, + np.float32: xla_client.xla_data_pb2.F32, + np.float64: xla_client.xla_data_pb2.F64, + } + + def _ConvertAndTest(template, src_dtype, dst_dtype): + c = self._NewComputation() + x = c.Constant(np.array(template, dtype=src_dtype)) + c.ConvertElementType(x, xla_types[dst_dtype]) + + result = c.Build().Compile().Execute() + expected = np.array(template, dtype=dst_dtype) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(result.dtype, expected.dtype) + np.testing.assert_equal(result, expected) + + x = [0, 1, 0, 0, 1] + for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): + _ConvertAndTest(x, src_dtype, dst_dtype) + + def testCrossReplicaSumOneReplica(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + c.CrossReplicaSum(c.Constant(lhs)) + self._ExecuteAndCompareExact(c, expected=lhs) + + def testDotMatrixVectorF32(self): + c = self._NewComputation() + lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF32([[10.0], [20.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixVectorF64(self): + c = self._NewComputation() + lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF64([[10.0], [20.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixMatrixF32(self): + c = self._NewComputation() + lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixMatrixF64(self): + c = self._NewComputation() + lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testConvF32Same(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 3, 4) + rhs = a(1, 2, 1, 2) * 10 + c.Conv(c.Constant(lhs), c.Constant(rhs), + [1, 1], xla_client.PaddingType.SAME) + result = np.array([[[[640., 700., 760., 300.], + [880., 940., 1000., 380.], + [1120., 1180., 1240., 460.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvF32Valid(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 3, 4) + rhs = a(1, 2, 1, 2) * 10 + c.Conv(c.Constant(lhs), c.Constant(rhs), + [2, 1], xla_client.PaddingType.VALID) + result = np.array([[[[640., 700., 760.], + [1120., 1180., 1240.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvWithGeneralPaddingF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testBooleanNot(self): + c = self._NewComputation() + arr = NumpyArrayBool([True, False, True]) + c.Not(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=~arr) + + def testExp(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Exp(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + + def testLog(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log(arr)) + + def testNeg(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Neg(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=-arr) + + def testFloor(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Floor(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.floor(arr)) + + def testCeil(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Ceil(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.ceil(arr)) + + def testAbs(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) + c.Abs(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.abs(arr)) + + def testTanh(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Tanh(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.tanh(arr)) + + def testTrans(self): + + def _TransposeAndTest(array): + c = self._NewComputation() + c.Trans(c.Constant(array)) + self._ExecuteAndCompareClose(c, expected=array.T) + + # Test square and non-square matrices in both default (C) and F orders. + for array_fun in [NumpyArrayF32, NumpyArrayF64]: + _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]])) + _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F")) + _TransposeAndTest(array_fun([[1, 2], [4, 5]])) + _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F")) + + def testTranspose(self): + + def _TransposeAndTest(array, permutation): + c = self._NewComputation() + c.Transpose(c.Constant(array), permutation) + expected = np.transpose(array, permutation) + self._ExecuteAndCompareClose(c, expected=expected) + + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) + + arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) + for permutation in itertools.permutations(range(arr.ndim)): + _TransposeAndTest(arr, permutation) + _TransposeAndTest(np.asfortranarray(arr), permutation) + + def testEq(self): + c = self._NewComputation() + c.Eq( + c.Constant(NumpyArrayS32([1, 2, 3, 4])), + c.Constant(NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + + def testNe(self): + c = self._NewComputation() + c.Ne( + c.Constant(NumpyArrayS32([1, 2, 3, 4])), + c.Constant(NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, True]) + + c.Ne( + c.Constant(NumpyArrayF32([-2.0, 0.0, + float("nan"), + float("nan")])), + c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) + self._ExecuteAndAssertWith( + np.testing.assert_allclose, c, (), expected=[True, False, True, True]) + + def testGt(self): + c = self._NewComputation() + c.Gt( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False]) + + def testGe(self): + c = self._NewComputation() + c.Ge( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False]) + + def testLt(self): + c = self._NewComputation() + c.Lt( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True]) + + def testLe(self): + c = self._NewComputation() + c.Le( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True]) + + def testMax(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0]) + + def testMaxExplicitBroadcastDim0(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]]) + + def testMaxExplicitBroadcastDim1(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]]) + + def testMin(self): + c = self._NewComputation() + c.Min( + c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) + + def testPad(self): + c = self._NewComputation() + c.Pad( + c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + c.Constant(NumpyArrayF32(0.0)), + [(1, 2, 1), (0, 1, 0)]) + self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], + [1.0, 2.0, 0.0], + [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]) + + def testPadWithPaddingConfig(self): + c = self._NewComputation() + padding_config = xla_client.xla_data_pb2.PaddingConfig() + for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: + dimension = padding_config.dimensions.add() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + c.Pad( + c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + c.Constant(NumpyArrayF32(0.0)), + padding_config) + self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], + [1.0, 2.0, 0.0], + [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]) + + def testReshape(self): + c = self._NewComputation() + c.Reshape( + c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), + dimensions=[0, 1], + new_sizes=[2, 3]) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]]) + + def testCollapse(self): + c = self._NewComputation() + c.Collapse( + c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[1, 2]) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]]) + + def testRev(self): + c = self._NewComputation() + c.Rev( + c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[0, 2]) + self._ExecuteAndCompareExact( + c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]) + + def testSelect(self): + c = self._NewComputation() + c.Select( + c.Constant(NumpyArrayBool([True, False, False, True, False])), + c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), + c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) + self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5]) + + def testSlice(self): + c = self._NewComputation() + c.Slice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], + [3, 2]) + self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + + def testDynamicSlice(self): + c = self._NewComputation() + c.DynamicSlice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayS32([1, 0])), [2, 2]) + self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + + def testDynamicUpdateSlice(self): + c = self._NewComputation() + c.DynamicUpdateSlice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), + c.Constant(NumpyArrayS32([1, 1]))) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]]) + + def testTuple(self): + c = self._NewComputation() + c.Tuple( + c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), + c.Constant(NumpyArrayBool([True, False, False, True]))) + result = c.Build().Compile().Execute() + self.assertIsInstance(result, tuple) + np.testing.assert_equal(result[0], 42) + np.testing.assert_allclose(result[1], [1.0, 2.0]) + np.testing.assert_equal(result[2], [True, False, False, True]) + + def testGetTupleElement(self): + c = self._NewComputation() + c.GetTupleElement( + c.Tuple( + c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), + c.Constant(NumpyArrayBool([True, False, False, True]))), 1) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0]) + + def testBroadcast(self): + c = self._NewComputation() + c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) + self._ExecuteAndCompareExact( + c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) + + def testRngNormal(self): + shape = (2, 3) + c = self._NewComputation() + c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)), + dims=shape) + result = c.Build().Compile().Execute() + # since the result is random, we just check shape and uniqueness + self.assertEqual(result.shape, shape) + self.assertEqual(len(np.unique(result)), np.prod(shape)) + + def testRngUniformF32(self): + lo, hi = 2., 4. + shape = (2, 3) + c = self._NewComputation() + c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)), + dims=shape) + result = c.Build().Compile().Execute() + # since the result is random, we just check shape, uniqueness, and range + self.assertEqual(result.shape, shape) + self.assertEqual(len(np.unique(result)), np.prod(shape)) + self.assertTrue(np.all(lo <= result)) + self.assertTrue(np.all(result < hi)) + + def testRngUniformS32(self): + lo, hi = 2, 4 + shape = (2, 3) + c = self._NewComputation() + c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)), + dims=shape) + result = c.Build().Compile().Execute() + # since the result is random, we just check shape, integrality, and range + self.assertEqual(result.shape, shape) + self.assertEqual(result.dtype, np.int32) + self.assertTrue(np.all(lo <= result)) + self.assertTrue(np.all(result < hi)) + + +class EmbeddedComputationsTest(LocalComputationTest): + """Tests for XLA graphs with embedded computations (such as maps).""" + + def _CreateConstantS32Computation(self): + """Computation (f32) -> s32 that returns a constant 1 for any input.""" + c = self._NewComputation("constant_s32_one") + # TODO(eliben): consider adding a nicer way to create new parameters without + # having to create dummy Numpy arrays or populating Shape messages. Perhaps + # we need our own (Python-client-own) way to represent Shapes conveniently. + c.ParameterFromNumpy(NumpyArrayF32(0)) + c.ConstantS32Scalar(1) + return c.Build() + + def _CreateConstantS64Computation(self): + """Computation (f64) -> s64 that returns a constant 1 for any input.""" + c = self._NewComputation("constant_s64_one") + # TODO(eliben): consider adding a nicer way to create new parameters without + # having to create dummy Numpy arrays or populating Shape messages. Perhaps + # we need our own (Python-client-own) way to represent Shapes conveniently. + c.ParameterFromNumpy(NumpyArrayF64(0)) + c.ConstantS64Scalar(1) + return c.Build() + + def _CreateConstantF32Computation(self): + """Computation (f32) -> f32 that returns a constant 1.0 for any input.""" + c = self._NewComputation("constant_f32_one") + c.ParameterFromNumpy(NumpyArrayF32(0)) + c.ConstantF32Scalar(1.0) + return c.Build() + + def _CreateConstantF64Computation(self): + """Computation (f64) -> f64 that returns a constant 1.0 for any input.""" + c = self._NewComputation("constant_f64_one") + c.ParameterFromNumpy(NumpyArrayF64(0)) + c.ConstantF64Scalar(1.0) + return c.Build() + + def _CreateMulF32By2Computation(self): + """Computation (f32) -> f32 that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f32_by2") + c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) + return c.Build() + + def _CreateMulF64By2Computation(self): + """Computation (f64) -> f64 that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f64_by2") + c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) + return c.Build() + + def _CreateBinaryAddF32Computation(self): + """Computation (f32, f32) -> f32 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryAddF64Computation(self): + """Computation (f64, f64) -> f64 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _CreateBinaryDivF32Computation(self): + """Computation (f32, f32) -> f32 that divides its two parameters.""" + c = self._NewComputation("div_param0_by_param1") + c.Div( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryDivF64Computation(self): + """Computation (f64, f64) -> f64 that divides its two parameters.""" + c = self._NewComputation("div_param0_by_param1") + c.Div( + c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _CreateTestF32Lt10Computation(self): + """Computation (f32) -> bool that tests if its parameter is less than 10.""" + c = self._NewComputation("test_f32_lt_10") + c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.)) + return c.Build() + + def _CreateTestF64Lt10Computation(self): + """Computation (f64) -> bool that tests if its parameter is less than 10.""" + c = self._NewComputation("test_f64_lt_10") + c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.)) + return c.Build() + + def _CreateBinaryGeF32Computation(self): + """Computation (f32, f32) -> bool that tests first_param >= second_param.""" + c = self._NewComputation("param0_lt_param1") + c.Ge(c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryGeF64Computation(self): + """Computation (f64, f64) -> bool that tests first_param >= second_param.""" + c = self._NewComputation("param0_lt_param1") + c.Ge(c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _MakeSample3DArrayF32(self): + return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) + + def _MakeSample3DArrayF64(self): + return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) + + def testCallF32(self): + c = self._NewComputation() + c.Call( + self._CreateMulF32By2Computation(), + operands=(c.ConstantF32Scalar(5.0),)) + self._ExecuteAndCompareClose(c, expected=10.0) + + def testCallF64(self): + c = self._NewComputation() + c.Call( + self._CreateMulF64By2Computation(), + operands=(c.ConstantF64Scalar(5.0),)) + self._ExecuteAndCompareClose(c, expected=10.0) + + def testMapEachElementToS32Constant(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantS32Computation(), [0]) + self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + + def testMapEachElementToS64Constant(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantS64Computation(), [0]) + self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + + def testMapMulBy2F32(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF32By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + + def testMapMulBy2F64(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF64By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + + def testSimpleMapChainF32(self): + # Chains a map of constant-f32 with a map of mul-by-2 + c = self._NewComputation() + const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantF32Computation(), [0]) + c.Map([const_f32], self._CreateMulF32By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + + def testSimpleMapChainF64(self): + # Chains a map of constant-f64 with a map of mul-by-2 + c = self._NewComputation() + const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantF64Computation(), [0]) + c.Map([const_f64], self._CreateMulF64By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + + def testDivVectorsWithMapF32(self): + c = self._NewComputation() + c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), + c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))), + self._CreateBinaryDivF32Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + + def testDivVectorsWithMapF64(self): + c = self._NewComputation() + c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), + c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))), + self._CreateBinaryDivF64Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + + def testSelectAndScatterF32(self): + c = self._NewComputation() + c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), + select=self._CreateBinaryGeF32Computation(), + window_dimensions=(2, 1), + window_strides=(1, 2), + padding=xla_client.PaddingType.VALID, + source=c.Constant(NumpyArrayF32([[0.1, 0.2]])), + init_value=c.Constant(NumpyArrayF32(1)), + scatter=self._CreateBinaryAddF32Computation()) + self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) + + def testSelectAndScatterF64(self): + c = self._NewComputation() + c.SelectAndScatter(c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])), + select=self._CreateBinaryGeF64Computation(), + window_dimensions=(2, 1), + window_strides=(1, 2), + padding=xla_client.PaddingType.VALID, + source=c.Constant(NumpyArrayF64([[0.1, 0.2]])), + init_value=c.Constant(NumpyArrayF64(1)), + scatter=self._CreateBinaryAddF64Computation()) + self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) + + def testReduce1DtoScalarF32(self): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=10) + + def testReduce1DtoScalarF64(self): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=10) + + def testReduce2DTo1DDim0F32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + + def testReduce2DTo1DDim0F64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + + def testReduce2DTo1DDim1F32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[1]) + self._ExecuteAndCompareClose(c, expected=[6, 15]) + + def testReduce2DTo1DDim1F64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[1]) + self._ExecuteAndCompareClose(c, expected=[6, 15]) + + def testReduce3DAllPossibleWaysF32(self): + input_array = self._MakeSample3DArrayF32() + + def _ReduceAndTest(*dims): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=dims) + self._ExecuteAndCompareClose( + c, expected=np.sum(input_array, axis=tuple(dims))) + + _ReduceAndTest(0) + _ReduceAndTest(0) + _ReduceAndTest(0, 1) + _ReduceAndTest(0, 2) + _ReduceAndTest(1, 2) + _ReduceAndTest(0, 1, 2) + + def testReduce3DAllPossibleWaysF64(self): + input_array = self._MakeSample3DArrayF64() + + def _ReduceAndTest(*dims): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=dims) + self._ExecuteAndCompareClose( + c, expected=np.sum(input_array, axis=tuple(dims))) + + _ReduceAndTest(0) + _ReduceAndTest(0) + _ReduceAndTest(0, 1) + _ReduceAndTest(0, 2) + _ReduceAndTest(1, 2) + _ReduceAndTest(0, 1, 2) + + def testReduceWindowValidUnitStridesF32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + window_dimensions=(2, 1), window_strides=(1, 1), + padding=xla_client.PaddingType.VALID) + self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) + + def testReduceWindowSameUnitStridesF32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + window_dimensions=(2, 1), window_strides=(1, 1), + padding=xla_client.PaddingType.SAME) + self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) + + def testReduceWindowValidGeneralStridesF32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + window_dimensions=(2, 1), window_strides=(1, 2), + padding=xla_client.PaddingType.VALID) + self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) + + def testReduceWindowValidUnitStridesF64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + window_dimensions=(2, 1), window_strides=(1, 1), + padding=xla_client.PaddingType.VALID) + self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) + + def testReduceWindowSameUnitStridesF64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + window_dimensions=(2, 1), window_strides=(1, 1), + padding=xla_client.PaddingType.SAME) + self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) + + def testReduceWindowValidGeneralStridesF64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + window_dimensions=(2, 1), window_strides=(1, 2), + padding=xla_client.PaddingType.VALID) + self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) + + def testWhileF32(self): + cond = self._CreateTestF32Lt10Computation() + body = self._CreateMulF32By2Computation() + c = self._NewComputation() + init = c.ConstantF32Scalar(1.) + c.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=16.) + + def testWhileF64(self): + cond = self._CreateTestF64Lt10Computation() + body = self._CreateMulF64By2Computation() + c = self._NewComputation() + init = c.ConstantF64Scalar(1.) + c.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=16.) + + def testInfeedS32Values(self): + to_infeed = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + c.Infeed(xla_client.Shape.from_numpy(to_infeed[0])) + compiled_c = c.Build().CompileWithExampleArguments() + for item in to_infeed: + xla_client.transfer_to_infeed(item) + + for item in to_infeed: + result = compiled_c.Execute() + self.assertEqual(result, item) + + def testInfeedThenOutfeedS32(self): + to_round_trip = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + x = c.Infeed(xla_client.Shape.from_numpy(to_round_trip[0])) + c.Outfeed(x) + + compiled_c = c.Build().CompileWithExampleArguments() + + for want in to_round_trip: + execution = threading.Thread(target=compiled_c.Execute) + execution.start() + xla_client.transfer_to_infeed(want) + got = xla_client.transfer_from_outfeed( + xla_client.Shape.from_numpy(to_round_trip[0])) + execution.join() + self.assertEqual(want, got) + + +class ErrorTest(LocalComputationTest): + + def setUp(self): + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.s32_scalar_2 = NumpyArrayS32(2) + + def testInvokeWithWrongElementType(self): + c = self._NewComputation() + c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) + c.ParameterFromNumpy(self.s32_scalar_2) + c.ClearOpMetadata() + self.assertRaisesRegexp( + RuntimeError, r"Invalid argument shape.*xla_client_test.py.*" + r"expected s32\[\], got f32\[\]", + lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 5bb81b80dde4c6d9324d33ddd5d6b6d6ad3cc1ac..a9acdae380af5b7f9efb3d08302fc717108f5e40 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -195,14 +195,26 @@ ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { std::vector dim_lengths{static_cast(operand.size())}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + return ReduceWindow1DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { + std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0]); @@ -269,6 +281,51 @@ ReferenceUtil::ReduceWindow1DAdd( return result; } +/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( + const Array3D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3()}; + auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + + std::vector window_counts(window.size(), 0); + std::vector pad_low(window.size(), 0); + for (int64 i = 0; i < window.size(); ++i) { + window_counts[i] = + WindowCount(dim_lengths[i], window[i], stride[i], padding); + pad_low[i] = padding_both[i].first; + } + auto result = MakeUnique>(window_counts[0], window_counts[1], + window_counts[2]); + + for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { + for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { + for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { + int64 i0_base = i0 * stride[0] - pad_low[0]; + int64 i1_base = i1 * stride[1] - pad_low[1]; + int64 i2_base = i2 * stride[2] - pad_low[2]; + + float val = init; + for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { + for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { + for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { + if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && + i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() && + i1_base + i1_win < operand.n2() && + i2_base + i2_win < operand.n3()) { + val += operand(i0_base + i0_win, i1_base + i1_win, + i2_base + i2_win); + } + } + } + } + (*result)(i0, i1, i2) = val; + } + } + } + return result; +} + /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, @@ -520,7 +577,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( HloEvaluator evaluator; std::unique_ptr result_literal = - evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); + evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); auto result = @@ -594,8 +651,12 @@ ReferenceUtil::ReduceToRowArray2D( i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { for (int64 i3 = 0; i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { - accumulator = reduce_function( - accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); + // Handle zero-sized arrays. + if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 && + array.n4() > 0) { + accumulator = reduce_function( + accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); + } } } } diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 62d455d71a70407e903a1e0be42a7e9f1898e523..3ec96f2f38b8f91e1549419b60481327fa9bbd5f 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -70,7 +70,7 @@ class ReferenceUtil { // dilation factors. static std::unique_ptr> ConvArray4DGeneralDimensionsDilated( const Array4D& lhs, const Array4D& rhs, - std::pair stride, Padding padding, + std::pair kernel_stride, Padding padding, std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums); @@ -173,6 +173,10 @@ class ReferenceUtil { const Array2D& operand, float init, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow3DAdd( + const Array3D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding); static std::unique_ptr> ReduceWindow4DAdd( const Array4D& operand, float init, const tensorflow::gtl::ArraySlice& window, @@ -184,11 +188,18 @@ class ReferenceUtil { const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + // With arbitrary padding. static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 846ccdc83df900e3afedb6ababe07ebb1bd68f41..9da9bc60a2025e63b57a3be9ed360d150f88d73c 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -86,6 +86,13 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { ErrorSpec(0.0001)); } +TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { + auto result = Literal::CreateR1(ReferenceUtil::Reduce4DTo1D( + Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, + [](float a, float b) { return a + b; })); + LiteralTestUtil::ExpectR1Equal({0}, *result); +} + TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index d3175c1e49974b060cc495d463d4995c925abcf7..71341c6f1e9a359a6d2a8aa9f2fb97b140ade23d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -108,6 +108,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -115,6 +116,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep @@ -1009,9 +1011,9 @@ tf_cc_test( ) cc_library( - name = "batchnorm_rewriter", - srcs = ["batchnorm_rewriter.cc"], - hdrs = ["batchnorm_rewriter.h"], + name = "batchnorm_expander", + srcs = ["batchnorm_expander.cc"], + hdrs = ["batchnorm_expander.h"], deps = [ ":hlo", ":hlo_pass", @@ -1029,11 +1031,11 @@ cc_library( ) tf_cc_test( - name = "batchnorm_rewriter_test", + name = "batchnorm_expander_test", size = "small", - srcs = ["batchnorm_rewriter_test.cc"], + srcs = ["batchnorm_expander_test.cc"], deps = [ - ":batchnorm_rewriter", + ":batchnorm_expander", ":hlo", ":hlo_matchers", ":hlo_pass", @@ -1099,6 +1101,8 @@ cc_library( ":hlo", ":hlo_evaluator", ":hlo_pass", + ":tuple_util", + ":while_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", ], @@ -1143,6 +1147,21 @@ tf_cc_test( ], ) +cc_library( + name = "dot_decomposer", + srcs = ["dot_decomposer.cc"], + hdrs = ["dot_decomposer.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -1663,6 +1682,7 @@ tf_cc_test( ":hlo", ":hlo_graph_dumper", ":hlo_matchers", + ":hlo_runner", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1670,7 +1690,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -1703,6 +1722,22 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_verifier_test", + srcs = ["hlo_verifier_test.cc"], + deps = [ + ":hlo", + ":hlo_verifier", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_rematerialization", srcs = ["hlo_rematerialization.cc"], @@ -1889,6 +1924,22 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_element_type_converter", + srcs = ["hlo_element_type_converter.cc"], + hdrs = ["hlo_element_type_converter.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + cc_library( name = "device_memory_allocator", srcs = ["device_memory_allocator.cc"], @@ -2021,6 +2072,7 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", ], alwayslink = 1, @@ -2074,6 +2126,41 @@ tf_cc_test( ], ) +cc_library( + name = "zero_sized_hlo_elimination", + srcs = ["zero_sized_hlo_elimination.cc"], + hdrs = ["zero_sized_hlo_elimination.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "zero_sized_hlo_elimination_test", + srcs = ["zero_sized_hlo_elimination_test.cc"], + deps = [ + ":hlo", + ":shape_inference", + ":zero_sized_hlo_elimination", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + cc_library( name = "pool", hdrs = ["pool.h"], @@ -2170,6 +2257,78 @@ cc_library( ], ) +cc_library( + name = "tuple_util", + srcs = ["tuple_util.cc"], + hdrs = ["tuple_util.h"], + deps = [ + ":hlo", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "tuple_util_test", + srcs = ["tuple_util_test.cc"], + deps = [ + ":tuple_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "while_util", + srcs = ["while_util.cc"], + hdrs = ["while_util.h"], + deps = [ + ":call_inliner", + ":hlo", + ":tuple_util", + ], +) + +tf_cc_test( + name = "while_util_test", + srcs = ["while_util_test.cc"], + deps = [ + ":while_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "while_loop_invariant_code_motion", + srcs = ["while_loop_invariant_code_motion.cc"], + hdrs = ["while_loop_invariant_code_motion.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":tuple_util", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_invariant_code_motion_test", + srcs = ["while_loop_invariant_code_motion_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_invariant_code_motion", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 71491218aa221cb26ea45f288ddc47173a15df3f..90a3f0b6748fc00c9cd9226700805bf243a1acdd 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -193,6 +193,33 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} + // Transforms Dots where at least one input is a vector or has a degenerate + // dimension and converts it into a multiply and reduce. This should enable + // more fusion than leaving the nodes as Dot operations. + StatusOr HandleDotStrengthReduction(HloInstruction* dot); + + // Reshapes an instruction to rank 1 if it is not already rank 1. + HloInstruction* Flatten(HloInstruction* hlo) { + if (ShapeUtil::Rank(hlo->shape()) == 1) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(hlo->shape().element_type(), + {ShapeUtil::ElementsIn(hlo->shape())}), + hlo)); + } + + // Helper method to perform and add reduction in a single dimension. + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + HloInstruction* zero = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* AddReduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + return computation_->AddInstruction(HloInstruction::CreateReduce( + shape, hlo, zero, {dim}, AddReduce_computation)); + } + // Convenience method for replacing an instruction with a bitcast. void ReplaceWithBitcast(HloInstruction* instruction); @@ -252,6 +279,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + StatusOr OptimizeDotOfConcat(HloInstruction* dot); + StatusOr OptimizeDotOfConcatHelper( + const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -329,6 +361,39 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } + // Canonicalization: Put constants on the right. This makes the reassociation + // rules below simpler. + VLOG(10) << "trying transform [Const + A => A + Const]"; + if (lhs->IsConstant() && !rhs->IsConstant()) { + return ReplaceWithNewInstruction( + add, + HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs)); + } + + // Reassociate to allow constant folding. + // + // Note: This is not general. For example, we won't reassociate + // + // (A + C1) + (B + C2) => A + B + (C1 + C2). + // + VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; + if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd && + !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) { + auto* c1 = lhs->mutable_operand(1); + auto* c2 = rhs; + TF_ASSIGN_OR_RETURN( + Shape sum_of_constants_shape, + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, c1, c2)); + + auto* sum_of_constants = + computation_->AddInstruction(HloInstruction::CreateBinary( + sum_of_constants_shape, HloOpcode::kAdd, c1, c2)); + return ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, + lhs->mutable_operand(0), + sum_of_constants)); + } + return Status::OK(); } @@ -433,13 +498,14 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, if (ShapeUtil::IsTuple(literal.shape())) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); - for (const Literal& child : literal.tuple_literals()) { - elems.push_back(BuildTupleConstant(computation, child)); + for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { + elems.push_back( + BuildTupleConstant(computation, LiteralView::Create(literal, {i}))); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique(literal))); + HloInstruction::CreateConstant(literal.CloneToUnique())); } } @@ -462,6 +528,16 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { return Status::OK(); } + // Canonicalize subtraction of a constant to addition. + VLOG(10) << "trying transform [A - Const => A + (-Const)]"; + if (rhs->IsConstant() && !lhs->IsConstant()) { + HloInstruction* negative_const = computation_->AddInstruction( + HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); + return ReplaceWithNewInstruction( + sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs, + negative_const)); + } + return Status::OK(); } @@ -523,6 +599,23 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } + // A / Const => A * (1 / Const) + // + // (Backends can do this transformation, but generally only if the constant is + // a scalar.) + if (lhs->opcode() != HloOpcode::kConstant && + rhs->opcode() == HloOpcode::kConstant) { + HloInstruction* one = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::One(lhs->shape().element_type()).CloneToUnique())); + HloInstruction* inverse = + computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kDivide, one, rhs)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs, inverse)); + } + // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) if (lhs->opcode() == HloOpcode::kDivide && rhs->opcode() == HloOpcode::kDivide) { @@ -574,70 +667,72 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { - auto lhs = dot->mutable_operand(0); - auto rhs = dot->mutable_operand(1); +StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( + HloInstruction* dot) { + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int64 lhs_collapsing_dim = + dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + if (lhs->IsRank2Transpose()) { + lhs = lhs->mutable_operand(0); + lhs_collapsing_dim = 1 - lhs_collapsing_dim; + } + const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; + + int64 rhs_collapsing_dim = + dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + if (rhs->IsRank2Transpose()) { + rhs = rhs->mutable_operand(0); + rhs_collapsing_dim = 1 - rhs_collapsing_dim; + } + const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; + + auto reshape_if_necessary = [&](HloInstruction* hlo) { + if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + return hlo; + } + return computation_->AddInstruction( + HloInstruction::CreateReshape(dot->shape(), hlo)); + }; - // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or - // below. - if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || - ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { - return Status::OK(); - } + auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, + int64 dim) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, {dim})); + }; - // Replace a zero element dot with a broadcast of the constant 0. - if (ShapeUtil::HasZeroElements(dot->shape()) || - ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); - } + auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { + return computation_->AddInstruction(HloInstruction::CreateBinary( + local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs)); + }; - // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). - if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { - auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot, - rhs->mutable_operand(0), lhs->mutable_operand(0))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); + // Strength reduce dot(a[K] , b[K]) = + // reshape(result.shape, + // reduce_sum(multiply(a, b), {0})) + if (ShapeUtil::Rank(rhs->shape()) == 1 && + ShapeUtil::Rank(lhs->shape()) == 1) { + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(AddReduce( + multiply(Flatten(lhs), Flatten(rhs)), 0)))); + return true; } - if (!enable_dot_strength_reduction_) { - return Status::OK(); + if (ShapeUtil::IsEffectiveScalar(rhs->shape()) && + ShapeUtil::IsEffectiveScalar(lhs->shape())) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs))))); + return true; } // Simplify outer product into multiply with implicit broadcasting. // // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) { - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, - lhs, rhs)); - } - - // The following graph transformations take Dots where at least one input is a - // vector or has a degenerate dimension and converts it into a multiply and - // reduce. This should enable more fusion than leaving the nodes as Dot - // operations. - - // Strength reduce dot(a[K] , b[K]) = - // reshape(result.shape, - // reduce_sum(multiply(a, b), {0})) - if (ShapeUtil::Rank(rhs->shape()) == 1 && - ShapeUtil::Rank(lhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, lhs, rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + if (ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), + broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); + return true; } // Strength reduce dot(a[1, K], b) = @@ -648,35 +743,21 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // ) // ) if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) { - auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(lhs->shape().element_type(), - {ShapeUtil::ElementsIn(lhs->shape())}), - lhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* reduce; + (ShapeUtil::Rank(lhs->shape()) == 2 && + lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - } else { - new_lhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), - {rhs->shape().dimensions(1)}), - multiply, zero, {0}, add_reduce_computation)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, + reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0)))); + return true; } - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), + rhs_collapsing_dim), + rhs), + rhs_collapsing_dim)))); + return true; } // Strength reduce dot(a, b[K, 1]) = @@ -684,26 +765,208 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) { - auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(rhs->shape().element_type(), - {ShapeUtil::ElementsIn(rhs->shape())}), - rhs)); - new_rhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + (ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_kept_dim) == 1)) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(AddReduce( + multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), + lhs_collapsing_dim)), + lhs_collapsing_dim)))); + return true; + } + return false; +} + +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( + HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_contracting_dimensions_size() != 1 || + dnums.lhs_batch_dimensions_size() != 0) { + return nullptr; + } + + const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + + TF_ASSIGN_OR_RETURN( + HloInstruction * optimized_lhs_concat, + OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs, + rhs_contracting_dim, /*swapped=*/false)); + if (optimized_lhs_concat) { + return optimized_lhs_concat; + } + + return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs, + lhs_contracting_dim, /*swapped=*/true); +} + +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( + const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { + bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && + lhs->concatenate_dimension() == lhs_contracting_dim && + rhs->opcode() == HloOpcode::kConstant; + if (!can_optimize) { + return nullptr; + } + + // We're replacing this: + // + // +-----+-----+-----+ +-------------------+ + // | | | | | | + // | | | | | R_0 | + // | | | | | | + // | | | | +-------------------+ + // | | | | | | + // | L_0 | L_1 | L_2 | * | R_1 | + // | | | | | | + // | | | | +-------------------+ + // | | | | | | + // | | | | | R_2 | + // | | | | | | + // +-----+-----+-----+ +-------------------+ + // + // with this: + // + // [Sum over i] + // + // +-----+ +-------------------+ + // | | | | + // | | * | R_i | + // | | | | + // | | +-------------------+ + // | | + // | L_i | + // | | + // | | + // | | + // | | + // | | + // +-----+ + // + // where the LHS is a concatenate operation (so we can "split" the LHS tensor + // for free) and the RHS is a constant tensor (and thus can be split at + // compile time). In the future, we may also want to do this when both the + // LHS and the RHS are concatenate operations that line up along the dimension + // being contracted over. + // + // We should be able to generalize this transform to work on a non-constant + // RHS when/if we have in-place slices or support input-fusing slices into + // Dots. + + // Dimension numbers for the new dot instructions we'll create (L_i * R_i in + // the diagram above). + DotDimensionNumbers new_dot_dnums; + new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim + : lhs_contracting_dim); + new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim + : rhs_contracting_dim); + + // Here we use the MKN notation, where the contracted dimension has K + // elements and the two non-contracted dimensions have M and N elements. + HloInstruction* add_result = nullptr; + int64 rhs_contracting_dim_offset = 0; + int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim); + for (HloInstruction* concat_op : lhs->operands()) { + int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); + Shape rhs_slice_shape(rhs->shape()); + rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); + + std::array start_indices; + start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; + start_indices[1 - rhs_contracting_dim] = 0; + + std::array limit_indices; + limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k; + limit_indices[1 - rhs_contracting_dim] = n; + + HloInstruction* rhs_slice = + computation_->AddInstruction(HloInstruction::CreateSlice( + rhs_slice_shape, rhs, /*start_indices=*/start_indices, + /*limit_indices=*/limit_indices, /*strides=*/{1, 1})); + + // TODO(b/69062148): We can get rid of `swapped` once all backends support + // "non-canonical" contraction dimensions (that contracts dimension 1 of the + // LHS with dimension 0 of the RHS). But for now we keep the same + // contraction dimensions as the incoming dot operation to ensure the new + // dot operations can be lowered. + HloInstruction *new_dot_lhs, *new_dot_rhs; + if (swapped) { + new_dot_lhs = rhs_slice; + new_dot_rhs = concat_op; + } else { + new_dot_lhs = concat_op; + new_dot_rhs = rhs_slice; + } + + auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums)); + + if (add_result) { + add_result = computation_->AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, add_result, new_dot)); + } else { + add_result = new_dot; + } + + rhs_contracting_dim_offset += sub_k; + } + + return add_result; +} + +Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { + auto lhs = dot->mutable_operand(0); + auto rhs = dot->mutable_operand(1); + + // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or + // below. + if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || + ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { + return Status::OK(); + } + + // Replace a zero element dot with a broadcast of the constant 0. + if (ShapeUtil::HasZeroElements(dot->shape()) || + ShapeUtil::HasZeroElements(lhs->shape()) || + ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), - {lhs->shape().dimensions(0)}), - multiply, zero, {1}, add_reduce_computation)); return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, + OptimizeDotOfConcat(dot)); + if (dot_of_concat_optimized) { + VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., " + "constant)...)"; + return ReplaceInstruction(dot, dot_of_concat_optimized); + } + + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { + TF_ASSIGN_OR_RETURN(bool did_strength_reduction, + HandleDotStrengthReduction(dot)); + if (did_strength_reduction) { + return Status::OK(); + } } + + // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). + if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), + rhs->mutable_operand(0), lhs->mutable_operand(0), + dot_dimension_numbers)); + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); + } + return Status::OK(); } @@ -980,6 +1243,11 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { } Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { + if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) { + return ReplaceWithNewInstruction( + pad, HloInstruction::CreateBroadcast(pad->shape(), + pad->mutable_operand(1), {})); + } // Eliminate nop pads (padding all zero), and replace a pad with negative // padding with a pad with non-negative padding followed by a slice. bool all_zero = true; @@ -1120,6 +1388,27 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, broadcast_one, lhs)); } + + VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: " + << power->ToString(); + + // Don't perform this optimization if either of the exponents is complex; this + // identity is true only for real-valued exponents. In addition, we cowardly + // refuse to do this transformation if the two expontents have different + // element types. + if (lhs->opcode() == HloOpcode::kPower && + !ShapeUtil::ElementIsComplex(lhs->operand(1)->shape()) && + !ShapeUtil::ElementIsComplex(rhs->shape()) && + ShapeUtil::SameElementType(lhs->operand(1)->shape(), rhs->shape())) { + auto exponent_product = + computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + return ReplaceWithNewInstruction( + power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kPower, + lhs->mutable_operand(0), + exponent_product)); + } + return Status::OK(); } @@ -1173,7 +1462,7 @@ StatusOr AlgebraicSimplifierVisitor:: ShapeUtil::MakeShapeWithLayout( user->shape().element_type(), AsInt64Slice(operand->shape().dimensions()), - AsInt64Slice(operand->shape().layout().minor_to_major())), + LayoutUtil::MinorToMajor(operand->shape())), new_user_operands)); VLOG(4) << " new user: " << new_user->ToString(); HloInstruction* new_reshape_or_broadcast = nullptr; @@ -1183,8 +1472,7 @@ StatusOr AlgebraicSimplifierVisitor:: ShapeUtil::MakeShapeWithLayout( user->shape().element_type(), AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - AsInt64Slice( - reshape_or_broadcast->shape().layout().minor_to_major())), + LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), new_user)); } else { TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); @@ -1193,8 +1481,7 @@ StatusOr AlgebraicSimplifierVisitor:: ShapeUtil::MakeShapeWithLayout( user->shape().element_type(), AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - AsInt64Slice( - reshape_or_broadcast->shape().layout().minor_to_major())), + LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), new_user, reshape_or_broadcast->dimensions())); } VLOG(4) << " new reshape/broadcast: " @@ -1403,6 +1690,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* reduce_window) { + if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) { + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateBroadcast(reduce_window->shape(), + reduce_window->mutable_operand(1), {})); + } auto operand = reduce_window->mutable_operand(0); const Window& window = reduce_window->window(); auto function = reduce_window->to_apply(); @@ -1473,7 +1766,6 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); - if (std::is_sorted(transpose->dimensions().begin(), transpose->dimensions().end())) { VLOG(10) << "deleting no-op transpose"; @@ -1500,6 +1792,18 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution) { auto lhs = convolution->mutable_operand(0); auto rhs = convolution->mutable_operand(1); + if (ShapeUtil::HasZeroElements(lhs->shape()) || + ShapeUtil::HasZeroElements(rhs->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(convolution->shape().element_type(), {}), + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))), + {})); + } const auto& window = convolution->window(); if (!enable_conv_simplification_) { return Status::OK(); @@ -1556,15 +1860,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // still convert Conv into more efficient Matmul with operand transposition // (such as the transposition flags in cuBLAS SGEMM). if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || - input_shape.layout().minor_to_major(0) != + LayoutUtil::Minor(input_shape.layout(), 0) != dnums.input_feature_dimension() || - convolution_shape.layout().minor_to_major(0) != + LayoutUtil::Minor(convolution_shape.layout(), 0) != dnums.output_feature_dimension() || // The input feature dimension should come later in the minor-to-major // order. - (PositionInContainer(filter_shape.layout().minor_to_major(), + (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_input_feature_dimension()) < - PositionInContainer(filter_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { return Status::OK(); } @@ -1592,18 +1896,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // We already checked feature_dimension is most minor, so data in input_shape // and row-major {conv_width,input_channels} are bitwise identical. - const Shape new_input_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), {conv_width, input_channels}); + const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {conv_width, input_channels}); // We already checked input_feature_dimension is more major than // output_feature_dimension, so data in filter_shape and row-major // {input_channels,output_channels} are bitwise identical. - const Shape new_filter_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - filter_shape.element_type(), {input_channels, output_channels}); - const Shape dot_output_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - convolution_shape.element_type(), {conv_width, output_channels}); + const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( + filter_shape.element_type(), {input_channels, output_channels}); + const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( + convolution_shape.element_type(), {conv_width, output_channels}); // We cannot insert bitcasts if the layouts will not be compatible. // TODO(b/33178038): Consider inserting a transpose if a bitcast would be @@ -1616,8 +1917,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); - auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( - dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 56dfb1cf0bc22ed62653d1f0772fdcae58498c27..e7c4dfb0a1690683bbdb7e61067392b48fdba8a5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -71,6 +71,55 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that Const + A is canonicalized to A + Const. +TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Constant())); +} + +// Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. +TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction* constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(3.14159f))); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); @@ -139,6 +188,28 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { EXPECT_EQ(root, param0); } +// Test that A - Const is canonicalized to A + (-Const). +TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kSubtract, param0, constant)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); +} + // Test that (A/B)/C is simplified to A/(B*C). TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -327,6 +398,78 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { EXPECT_EQ(0, negate_shape.dimensions_size()); } +// A / Const => A * (1 / Const) +TEST_F(AlgebraicSimplifierTest, DivideByConstant) { + Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 1.f, 2.f}))); + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, + param0, constant)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Multiply(param0, op::Divide(op::Constant(), constant))); +} + +// pow(pow(A, X), Y) => pow(A, X*Y) +TEST_F(AlgebraicSimplifierTest, PowerOfPower) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); + HloComputation::Builder builder(TestName()); + HloInstruction* base = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* inner_power = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, + inner_power, exp2)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Power(base, op::Multiply(exp1, exp2))); +} + +// Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex +// numbers. +TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { + Shape r0c64 = ShapeUtil::MakeShape(C64, {}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); + HloComputation::Builder builder(TestName()); + HloInstruction* base = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0c64, "param1")); + HloInstruction* exp2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0c64, "param2")); + HloInstruction* inner_power = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, + inner_power, exp2)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + // Test that A/1 is simplified to A for a scalar. TEST_F(AlgebraicSimplifierTest, DivOneScalar) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -767,6 +910,120 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { 1); } +TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs")); + + HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs")); + + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.set_input_feature_dimension(2); + + dnums.set_output_batch_dimension(0); + dnums.add_output_spatial_dimensions(1); + dnums.set_output_feature_dimension(2); + + dnums.add_kernel_spatial_dimensions(0); + dnums.set_kernel_input_feature_dimension(1); + dnums.set_kernel_output_feature_dimension(2); + Window window; + WindowDimension* dim = window.add_dimensions(); + dim->set_size(3); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_window_reversal(false); + // Create add computation. + std::unique_ptr module = CreateNewModule(); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); + module->AddEntryComputation(builder.Build()); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Convolution(lhs, rhs)); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Broadcast(op::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 0}), "op")); + Window window; + for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_padding_low(1); + dim->set_padding_high(1); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + // Create add computation. + std::unique_ptr module = CreateNewModule(); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module->AddEmbeddedComputation(builder.Build()); + } + builder.AddInstruction(HloInstruction::CreateReduceWindow( + ShapeUtil::MakeShape(F32, {5, 2}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + window, add_computation)); + module->AddEntryComputation(builder.Build()); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::ReduceWindow(param, op::Constant())); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Broadcast(op::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 0}), "op")); + PaddingConfig padding; + for (int i = 0; i < 2; ++i) { + PaddingConfig::PaddingConfigDimension* dimension = padding.add_dimensions(); + dimension->set_edge_padding_low(1); + dimension->set_edge_padding_high(1); + dimension->set_interior_padding(0); + } + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {5, 2}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + padding)); + std::unique_ptr module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Pad(param, op::Constant())); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Broadcast(op::Constant())); +} + TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -1260,7 +1517,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(F32, {2, 2, 2}), + 0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}), "param0")); HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -2138,8 +2395,10 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kDot, x, y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); std::unique_ptr dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -2236,5 +2495,210 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +class DotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(DotStrengthReductionTest, DotStrengthReduction) { + int m, k, n; + bool transpose_lhs, transpose_rhs; + std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( + 0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs")); + if (transpose_lhs) { + lhs = builder.AddInstruction( + HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0})); + } + auto rhs = builder.AddInstruction(HloInstruction::CreateParameter( + 1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs")); + if (transpose_rhs) { + rhs = builder.AddInstruction( + HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0})); + } + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool computation_should_be_modified = + dot_should_be_transformed || (transpose_lhs && transpose_rhs); + EXPECT_EQ(changed, computation_should_be_modified); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_CASE_P( + DotStrengthReductionTestInstantiation, DotStrengthReductionTest, + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Values(1, 2), ::testing::Bool(), + ::testing::Bool())); + +struct DotOfConcatTestSpec { + int64 m; + int64 k; + int64 n; +}; + +class DotOfConcatSimplificationTest + : public HloTestBase, + public ::testing::WithParamInterface {}; + +// Test that we transform +// dot(const, concat(A, B, C)) +// to +// add(dot(const_0, A), dot(const_1, B), dot(const_2, C)) +TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { + HloComputation::Builder builder(TestName()); + + DotOfConcatTestSpec spec = GetParam(); + + ASSERT_GE(spec.k, 3); + + int64 k0 = spec.k / 3; + int64 k1 = spec.k / 3; + int64 k2 = spec.k - k0 - k1; + + Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k))); + + Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n}); + Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n}); + Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n}); + + HloInstruction* rhs0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, rhs0_shape, "rhs0")); + HloInstruction* rhs1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs1_shape, "rhs1")); + HloInstruction* rhs2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, rhs2_shape, "rhs2")); + + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + ASSERT_TRUE(run_successful); + + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); + auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); + auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); +} + +// Test that we transform +// dot(concat(A, B, C), const) +// to +// add(dot(A, const_0), dot(B, const_1), dot(C, const_2)) +TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { + HloComputation::Builder builder(TestName()); + + DotOfConcatTestSpec spec = GetParam(); + + ASSERT_GE(spec.k, 4); + + int64 k0 = spec.k / 4; + int64 k1 = spec.k / 4; + int64 k2 = spec.k / 4; + int64 k3 = spec.k - k0 - k1 - k2; + + Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0}); + Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1}); + Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2}); + Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3}); + + HloInstruction* lhs0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs0_shape, "lhs0")); + HloInstruction* lhs1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, lhs1_shape, "lhs1")); + HloInstruction* lhs2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, lhs2_shape, "lhs2")); + HloInstruction* lhs3 = builder.AddInstruction( + HloInstruction::CreateParameter(3, lhs2_shape, "lhs3")); + + Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); + HloInstruction* lhs = + builder.AddInstruction(HloInstruction::CreateConcatenate( + lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1)); + + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.m}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.m))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(module.get())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); + auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); + auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); + auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), + match_dot_3)); +} + +DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { + {/*m=*/3, /*k=*/9, /*n=*/3}, // + {/*m=*/3, /*k=*/20, /*n=*/3}, // + {/*m=*/1, /*k=*/18, /*n=*/5}, // + {/*m=*/20, /*k=*/20, /*n=*/1}, // + {/*m=*/1, /*k=*/16, /*n=*/1}, // +}; + +INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, + DotOfConcatSimplificationTest, + ::testing::ValuesIn(kDotOfConcatTestSpecs)); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index ad2fee2d39a8ca183b87212bdeea22c351aaa88a..4e80679c11dfdf7fdf8077a9f354139a4cab6803 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -27,191 +27,161 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -namespace se = ::perftools::gputools; namespace xla { -AllocationTracker::AllocationTracker() : next_handle_(1) {} - -GlobalDataHandle AllocationTracker::Register(Backend* backend, - int device_ordinal, - se::DeviceMemoryBase device_memory, - const Shape& shape, - const string& tag) { - tensorflow::mutex_lock lock(allocation_mutex_); +StatusOr AllocationTracker::Register( + std::unique_ptr shaped_buffer, const string& tag) { + tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Register"; - return RegisterInternal(backend, device_ordinal, device_memory, shape, tag, - /*initial_ref_count=*/1); + return RegisterInternal(std::move(shaped_buffer), tag); } -GlobalDataHandle AllocationTracker::RegisterInternal( - Backend* backend, int device_ordinal, se::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag, int initial_ref_count) { +StatusOr AllocationTracker::RegisterInternal( + std::unique_ptr shaped_buffer, const string& tag) { VLOG(2) << "RegisterInternal(" << "tag: \"" << tag << "\" " - << "device_ordinal: " << device_ordinal << " " - << "device_memory: " << device_memory.opaque() << " " - << "shape: " << shape.ShortDebugString() << ")"; - TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); - - int64 handle; - HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal); - auto handle_it = handle_map.find(device_memory.opaque()); - if (handle_it != handle_map.end()) { - handle = handle_it->second; - auto& allocation = FindOrDie(handle_to_allocation_, handle); - int ref_count = allocation->ref_count(); - CHECK_GT(ref_count, 0); - VLOG(2) << "ref_count: " << ref_count << " -> " << - (ref_count + initial_ref_count); - allocation->increment_ref_count(initial_ref_count); - } else { - handle = next_handle_++; - VLOG(2) << "ref_count: " << initial_ref_count; - InsertOrDie(&handle_map, device_memory.opaque(), handle); - auto inserted = handle_to_allocation_.emplace( - handle, MakeUnique(backend, device_ordinal, device_memory, - shape, tag, initial_ref_count)); - CHECK(inserted.second); + << "shaped_buffer: " << *shaped_buffer; + if (shaped_buffer->platform() != backend_->platform()) { + return InvalidArgument( + "AllocationTracker for platform %s cannot register buffer from " + "platform %s", + backend_->platform()->Name().c_str(), + shaped_buffer->platform()->Name().c_str()); } + int64 handle = next_handle_++; + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal()); + } GlobalDataHandle result; result.set_handle(handle); + + handle_to_shaped_buffer_[handle] = std::move(shaped_buffer); + VLOG(2) << "handle: " << handle; return result; } tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); - TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data)); - std::set deallocated_buffers; - TF_RETURN_IF_ERROR( - DeallocateShape(allocation->backend(), allocation->device_ordinal(), - allocation->mutable_device_memory(), allocation->shape(), - &deallocated_buffers)); - return tensorflow::Status::OK(); -} - -tensorflow::Status AllocationTracker::DeallocateShape( - Backend* backend, int device_ordinal, se::DeviceMemoryBase* device_memory, - const Shape& shape, std::set* deallocated_buffers) { - VLOG(2) << "DeallocateShape(" - << "shape: \"" << shape.ShortDebugString() << "\" " - << "device_memory: " << device_memory->opaque() << ")"; - if (ContainsKey(*deallocated_buffers, device_memory->opaque())) { - // Buffer has already been deallocated. Nothing to do. - VLOG(2) << "already deallocated"; - return tensorflow::Status::OK(); - } - - // Add buffer to deallocated set so we do not try to deallocate it again - // if it is encountered again while traversing a tuple. - deallocated_buffers->insert(device_memory->opaque()); - - HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal); - auto handle_it = handle_map.find(device_memory->opaque()); - if (handle_it != handle_map.end()) { - int64 handle = handle_it->second; - auto& allocation = FindOrDie(handle_to_allocation_, handle); - int ref_count = allocation->ref_count(); - VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count - 1; - allocation->decrement_ref_count(); - if (allocation->ref_count() > 0) { - // Buffer is referred to by another allocation. Don't deallocate it. - return tensorflow::Status::OK(); - } - handle_map.erase(device_memory->opaque()); + tensorflow::mutex_lock lock(mutex_); + VLOG(2) << "Unregister(" + << "handle: " << data.handle() << ")"; + TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal())); } - if (ShapeUtil::IsTuple(shape)) { - // Traverse into tuple recursively deallocating buffers. - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - backend->stream_executor(device_ordinal)); - TF_ASSIGN_OR_RETURN(std::vector elements, - backend->transfer_manager()->ShallowCopyTupleFromDevice( - executor, *device_memory, shape)); - - TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size()) - << "tuple has unexpected number of elements: " << elements.size() - << " != " << ShapeUtil::TupleElementCount(shape); - for (size_t i = 0; i < elements.size(); ++i) { - VLOG(2) << "recursing onto the tuple elements"; - TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i], - shape.tuple_shapes(i), - deallocated_buffers)); - } - } + // Keep a nullptr as a tombstone for unregistered handles. This enables better + // error messages. That is, "handle has been deallocated" versus "handle does + // not exist". + handle_to_shaped_buffer_.at(data.handle()).reset(); - return backend->memory_allocator()->Deallocate(device_ordinal, device_memory); + return tensorflow::Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); - TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data)); + tensorflow::mutex_lock lock(mutex_); - if (!ShapeUtil::IsTuple(allocation->shape())) { + TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { return InvalidArgument("global data handle %lld is not a tuple", data.handle()); } + // If the on-host representation is a tuple, then the on-device one should be + // as well. + TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); - if (ShapeUtil::IsNestedTuple(allocation->shape())) { + if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { return Unimplemented("deconstructing nested tuples not yet supported"); } - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - allocation->backend()->stream_executor(allocation->device_ordinal())); - TF_ASSIGN_OR_RETURN( - std::vector element_bases, - allocation->backend()->transfer_manager()->ShallowCopyTupleFromDevice( - executor, allocation->device_memory(), allocation->shape())); - std::vector element_handles; - element_handles.reserve(element_bases.size()); - for (int i = 0; i < element_bases.size(); ++i) { - element_handles.push_back(RegisterInternal( - allocation->backend(), allocation->device_ordinal(), element_bases[i], - ShapeUtil::GetSubshape(allocation->shape(), {i}), - tensorflow::strings::StrCat(allocation->tag(), ".element_", i), - /*initial_ref_count=*/2)); + for (int i = 0; + i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape()); + ++i) { + auto element_buffer = MakeUnique( + ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i), + ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i), + shaped_buffer->platform(), shaped_buffer->device_ordinal()); + element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}), + /*index=*/{}); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle element_handle, + RegisterInternal(std::move(element_buffer), "deconstructed tuple")); + + element_handles.push_back(element_handle); } return std::move(element_handles); } -StatusOr AllocationTracker::Resolve( +StatusOr AllocationTracker::Resolve( const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); + tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } -StatusOr AllocationTracker::ResolveInternal( +StatusOr AllocationTracker::ResolveInternal( const GlobalDataHandle& data) { VLOG(2) << "resolve:" << data.handle(); - auto it = handle_to_allocation_.find(data.handle()); - if (it == handle_to_allocation_.end()) { + auto it = handle_to_shaped_buffer_.find(data.handle()); + if (it == handle_to_shaped_buffer_.end()) { return NotFound("no allocation record for global data handle: %lld", data.handle()); } - Allocation* allocation = it->second.get(); + ShapedBuffer* shaped_buffer = it->second.get(); - if (allocation->is_deallocated()) { + if (shaped_buffer == nullptr) { return InvalidArgument("global data handle %lld was previously deallocated", data.handle()); } - return allocation; + return shaped_buffer; +} + +void AllocationTracker::AddAllocationOrIncrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; + auto it = allocation_map.find(device_memory.opaque()); + if (it == allocation_map.end()) { + allocation_map[device_memory.opaque()] = {device_memory, device_ordinal, + /*ref_count=*/1}; + } else { + it->second.ref_count++; + } } -AllocationTracker::HandleMap& AllocationTracker::GetOrCreateOpaqueToHandleMap( - int device_ordinal) { - if (opaque_to_handle_.size() <= device_ordinal) { - opaque_to_handle_.resize(device_ordinal + 1); +Status AllocationTracker::DecrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; + auto it = allocation_map.find(device_memory.opaque()); + TF_RET_CHECK(it != allocation_map.end()); + Allocation& allocation = it->second; + TF_RET_CHECK(allocation.ref_count >= 1); + if (allocation.ref_count == 1) { + TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate( + device_ordinal, &device_memory)); + allocation_map.erase(it); + } else { + allocation.ref_count--; } - return opaque_to_handle_[device_ordinal]; + return tensorflow::Status::OK(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index ebbf35b6fe87bc7322ccb99cfe8f8eed56de06b3..807af8694972083d097604a67ee46d2f73d9545a 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -28,147 +28,92 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" namespace xla { -// A global allocation in device space, tracked by the XLA service. -class Allocation { - public: - Allocation(Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag, int initial_ref_count) - : backend_(backend), - device_ordinal_(device_ordinal), - device_memory_(device_memory), - shape_(shape), - tag_(tag), - ref_count_(initial_ref_count) {} - - Backend* backend() const { return backend_; } - int device_ordinal() const { return device_ordinal_; } - perftools::gputools::DeviceMemoryBase device_memory() const { - return device_memory_; - } - const Shape& shape() const { return shape_; } - const string& tag() const { return tag_; } - - bool is_deallocated() const { - CHECK_GE(ref_count_, 0); - return ref_count_ == 0; - } - int ref_count() const { - CHECK_GE(ref_count_, 0); - return ref_count_; - } - void increment_ref_count(int inc) { - CHECK_GT(ref_count_, 0); - CHECK_LE(ref_count_, INT_MAX - inc); - ref_count_ += inc; - } - void decrement_ref_count() { - CHECK_GT(ref_count_, 0); - --ref_count_; - } - perftools::gputools::DeviceMemoryBase* mutable_device_memory() { - return &device_memory_; - } - - private: - // The backend that the memory is allocated on. - Backend* backend_; - - // The device that the memory is allocated on. - int device_ordinal_; - - // The pointer to this allocation. - perftools::gputools::DeviceMemoryBase device_memory_; - - // The shape of this allocation. - Shape shape_; - - // An informal description of this allocation shown in tools. - string tag_; - - // This is the number of Allocation objects which refer to this memory - // allocation. - int ref_count_; - - // Return a string representation of this allocation for debugging or logging - // purposes. - string ToString() const; -}; - // Tracks allocations for the XLA service; allocations can be registered // with shape/device/tag and resolved from a handle for later use. class AllocationTracker { public: - AllocationTracker(); + // The allocator is used for deallocating memory when allocations are + // deregistered. All registered allocations must have the same platform as the + // allocator. + AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {} - // Registers device memory with a given shape, device identifier, and tag, and - // returns a corresponding handle that can be used for talking to XLA - // clients. - GlobalDataHandle Register(Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag); + // Registers a shaped buffer of device memory, and returns a corresponding + // handle that can be used for talking to XLA clients. + StatusOr Register( + std::unique_ptr shaped_buffer, const string& tag); // Unregister the allocation for the given data handle. - tensorflow::Status Unregister(const GlobalDataHandle& data); + Status Unregister(const GlobalDataHandle& data); // Returns a vector of global data handles that point to the tuple elements. StatusOr> DeconstructTuple( const GlobalDataHandle& Data); - // Resolve a handle from an XLA client to an allocation, or provide an - // error status to say whether it was not found (or found, but found - // deallocated). - StatusOr Resolve(const GlobalDataHandle& data); + // Resolve a handle from an XLA client to a shaped buffer, or provide an error + // status to say whether it was not found (or found, but found deallocated). + StatusOr Resolve(const GlobalDataHandle& data); private: - // Internal helper which resolves the given GlobalDataHandle to an Allocation. - StatusOr ResolveInternal(const GlobalDataHandle& data) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - GlobalDataHandle RegisterInternal( - Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, const Shape& shape, - const string& tag, int initial_ref_count) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - // Helper function which deallocates the memory buffer containing the given - // shape referred to by device_memory. Tuples are traversed recursively - // deallocating all nested buffers. The parameter deallocated_buffers contains - // the set of buffers deallocated so far stored as opaque values (void *) from - // DeviceMemoryBase. Keeping track of deallocated buffers prevents - // double-freeing of buffers which may be referred to more than once in a - // nested tuple. - tensorflow::Status DeallocateShape( - Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase* device_memory, const Shape& shape, - std::set* deallocated_buffers) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - // Returns the opaque_to_handle_ map for the given device_ordinal, creating - // a new map if there is not one for the device_ordinal. - using HandleMap = std::map; - HandleMap& GetOrCreateOpaqueToHandleMap(int device_ordinal) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - tensorflow::mutex allocation_mutex_; // Guards the allocation mapping. + // Data structure encapsulating single memory allocation on the device. + struct Allocation { + // The pointer to this allocation. + perftools::gputools::DeviceMemoryBase device_memory; + + // The device that the memory is allocated on. + int device_ordinal; + + // This is the number of times this memory allocation is referred to by + // registered data handles. + int ref_count; + }; + + // Internal helper which resolves the given GlobalDataHandle to a + // ShapedBuffer. + StatusOr ResolveInternal(const GlobalDataHandle& data) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Internal helper which registers a shaped buffer. + StatusOr RegisterInternal( + std::unique_ptr shaped_buffer, const string& tag) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Adds the given device address to the allocation tracker, or if it already + // exists, then increment it's reference count. + void AddAllocationOrIncrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Decrements the reference count of the given device memory. Then, if it is + // zero, deallocate the memory. + Status DecrementRefCount(perftools::gputools::DeviceMemoryBase device_memory, + int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // A map from device memory opaque value to allocation. One such map is + // maintained per device ordinal. + using AllocationMap = tensorflow::gtl::FlatMap; + + tensorflow::mutex mutex_; + + // Backend to use with this tracker. The backend supplies the memory allocator + // to use when deallocating memory. + Backend* backend_; // The next handle to assign to an allocation, guarded by the same mutex as // the mapping as they'll be mutated at the same time. - int64 next_handle_ GUARDED_BY(allocation_mutex_); + int64 next_handle_ GUARDED_BY(mutex_); - // A map from DeviceMemoryBase to handle for each device_ordinal. - std::vector opaque_to_handle_ GUARDED_BY(allocation_mutex_); + // A map from device ordinal to AllocationMap. + tensorflow::gtl::FlatMap opaque_to_allocation_map_ + GUARDED_BY(mutex_); - // Mapping from GlobalDataHandle handle to the corresponding registered - // Allocation object. - std::map> handle_to_allocation_ - GUARDED_BY(allocation_mutex_); + // A map from data handle to ShapedBuffer. + tensorflow::gtl::FlatMap> + handle_to_shaped_buffer_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); }; diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc similarity index 58% rename from tensorflow/compiler/xla/service/batchnorm_rewriter.cc rename to tensorflow/compiler/xla/service/batchnorm_expander.cc index c6193b3fbbd651088a823605af3ba84bca4a77ee..27ddfd47aa3096afd3e245af1ac3cedd9b48ce4a 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include #include @@ -45,9 +45,9 @@ limitations under the License. namespace xla { -// BatchNormRewriterVisitor traverses the HLO computation and rewrites BatchNorm +// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. -class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { +class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { public: // Default visitor action is to do nothing and return OK. Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { @@ -68,10 +68,10 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { // Returns whether any batch norm ops were rewritten. const bool changed() const { return changed_; } - ~BatchNormRewriterVisitor() override = default; + ~BatchNormExpanderVisitor() override = default; private: - explicit BatchNormRewriterVisitor(HloComputation* computation, + explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, bool rewrite_grad_op, bool use_fusion) @@ -94,7 +94,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } - // Current HloComputation instance the BatchNormRewriter is + // Current HloComputation instance the BatchNormExpander is // traversing. HloComputation* computation_; @@ -130,11 +130,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { } }; -bool BatchNormRewriterVisitor::Run(HloComputation* computation, +bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, bool rewrite_grad_op, bool use_fusion) { - BatchNormRewriterVisitor visitor( + BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, @@ -144,11 +144,20 @@ bool BatchNormRewriterVisitor::Run(HloComputation* computation, return visitor.changed_; } -Status BatchNormRewriterVisitor::HandleBatchNormTraining( +Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction* batch_norm) { if (!rewrite_training_op_) { return Status::OK(); } + + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + // Expand batch norm training into smaller HLO ops. HloInstruction* operand = batch_norm->mutable_operand(0); const Shape operand_shape = operand->shape(); @@ -160,7 +169,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( Literal::CreateR0(size_in_elements / feature_count); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = computation_->AddInstruction( + auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); HloInstruction* scale = batch_norm->mutable_operand(1); @@ -169,14 +178,12 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(zero_literal))); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); - + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -185,109 +192,116 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( } } - auto scale_broadcasted = computation_->AddInstruction( + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // X^2. - auto operand_squared = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, operand, operand)); // Sum[X]. - auto sum = computation_->AddInstruction(HloInstruction::CreateReduce( - feature_shape, operand, zero, dimensions_without_feature, - add_reduce_computation)); + auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, + dimensions_without_feature, + add_reduce_computation)); // Sum[X^2]. - auto squared_sum = computation_->AddInstruction(HloInstruction::CreateReduce( + auto squared_sum = add(HloInstruction::CreateReduce( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); // Fuse two parallel reduces together to improve performance. - if (use_fusion_) { - auto tuple = computation_->AddInstruction( - HloInstruction::CreateTuple({sum, squared_sum})); + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); auto fused = computation_->CreateFusionInstruction( {tuple, sum, squared_sum, operand_squared}, HloInstruction::FusionKind::kInput); - sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - squared_sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + squared_sum = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // E[X]. - auto mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto square_mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); // E^2[X]. - auto mean_square = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean_square = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, mean, mean)); // Var[X]. - auto var = computation_->AddInstruction(HloInstruction::CreateBinary( + auto var = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); auto neg_half_literal = Literal::CreateR0(-0.5f); TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = computation_->AddInstruction( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kAdd, - scaled_normalized, offset_broadcasted)); - - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({shifted_normalized, mean, var}))); + auto shifted_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + + auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); + + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + HloSharding operand_sharding = + batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0}); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(operand_sharding); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } -Status BatchNormRewriterVisitor::HandleBatchNormInference( +Status BatchNormExpanderVisitor::HandleBatchNormInference( HloInstruction* batch_norm) { if (!rewrite_inference_op_) { return Status::OK(); @@ -317,58 +331,75 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( } } - auto scale_broadcasted = computation_->AddInstruction( + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); auto neg_half_literal = Literal::CreateR0(-0.5f); TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. auto shifted_normalized = HloInstruction::CreateBinary( operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted); + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + if (batch_norm->has_sharding()) { + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + shifted_normalized->set_sharding(batch_norm->sharding()); + } TF_CHECK_OK( ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized))); return Status::OK(); } -Status BatchNormRewriterVisitor::HandleBatchNormGrad( +Status BatchNormExpanderVisitor::HandleBatchNormGrad( HloInstruction* batch_norm) { // Use the following formulas to calculate gradients: // scale_grad = @@ -385,6 +416,13 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( if (!rewrite_grad_op_) { return Status::OK(); } + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); const Shape activation_shape = activation->shape(); @@ -403,23 +441,22 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( Literal::CreateR0(size_in_elements / feature_count); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = computation_->AddInstruction( + auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(zero_literal))); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto neg_half_literal = Literal::CreateR0(-0.5f); TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; @@ -429,141 +466,148 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( } } - auto scale_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, scale, {feature_index})); - auto variance_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, variance, {feature_index})); + auto scale_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, scale, {feature_index})); + auto variance_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, variance, {feature_index})); // E[X]. - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kAdd, variance, epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)), + neg_half)); + + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, + epsilon)), + neg_half)); // X - E[X]. - auto activation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - activation, mean_broadcasted)); + auto activation_minus_mean = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); // Grad[Y] * (X - E[X]). - auto grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + auto grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, activation_minus_mean)); HloComputation* add_reduce_computation = GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // sum(Grad[Y] * (X - E[X])). auto sum_grad_output_times_activiation_minus_mean = - computation_->AddInstruction(HloInstruction::CreateReduce( + add(HloInstruction::CreateReduce( feature_shape, grad_output_times_activiation_minus_mean, zero, dimensions_without_feature, add_reduce_computation)); // Grad[beta] = Sum(Grad[Y]). - auto grad_beta = computation_->AddInstruction(HloInstruction::CreateReduce( + auto grad_beta = add(HloInstruction::CreateReduce( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_) { - auto tuple = computation_->AddInstruction(HloInstruction::CreateTuple( + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple( {sum_grad_output_times_activiation_minus_mean, grad_beta})); auto fused = computation_->CreateFusionInstruction( {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, HloInstruction::FusionKind::kInput); - sum_grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum_grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - grad_beta = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + grad_beta = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = computation_->AddInstruction(HloInstruction::CreateBinary( + auto grad_scale = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); // I2 = Sum(Grad[Y]) - auto I2 = computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, grad_beta, {feature_index})); + auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, + {feature_index})); // I3 = Sum(Grad[Y] * (X - E[X])) - auto I3 = computation_->AddInstruction(HloInstruction::CreateBroadcast( + auto i3 = add(HloInstruction::CreateBroadcast( activation_shape, sum_grad_output_times_activiation_minus_mean, {feature_index})); // I4 = (X - E[X]) * I3 - auto I4 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, I3, activation_minus_mean)); + auto i4 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); // I5 = I4 / (Var[X] + epsilon) - auto I5 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, I4, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kAdd, variance_broadcasted, epsilon)))); + auto i5 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, i4, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)))); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted)); - scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, - scale_times_rsqrt_var_add_epsilon, elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, + elements_per_feature)); - auto I1 = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto i1 = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, elements_per_feature)); // I6 = I1 - I2 - I5 - auto I6 = computation_->AddInstruction(HloInstruction::CreateBinary( + auto i6 = add(HloInstruction::CreateBinary( activation_shape, HloOpcode::kSubtract, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, I1, I2)), - I5)); + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, + i1, i2)), + i5)); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. - auto grad_activation = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, I6)); + auto grad_activation = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6)); + auto tuple = + HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + HloSharding activation_sharding = + batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0}); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), activation_shape)) { + inst->set_sharding(activation_sharding); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}))); + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } -StatusOr BatchNormRewriter::Run(HloModule* module) { - XLA_VLOG_LINES(2, "BatchNormRewriter::Run(), before:\n" + module->ToString()); +StatusOr BatchNormExpander::Run(HloModule* module) { + XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_, + if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, rewrite_inference_op_, rewrite_grad_op_, use_fusion_)) { changed = true; } } - XLA_VLOG_LINES(2, "BatchNormRewriter::Run(), after:\n" + module->ToString()); + XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.h b/tensorflow/compiler/xla/service/batchnorm_expander.h similarity index 83% rename from tensorflow/compiler/xla/service/batchnorm_rewriter.h rename to tensorflow/compiler/xla/service/batchnorm_expander.h index f601741d964376058a2bafade311ede4c8567fd2..4ad987085da91684bb7891070afeefd19be4138f 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ #include @@ -26,18 +26,18 @@ namespace xla { // A pass which rewrites batch norm operations into more operations. Breaking a // big operation into smaller operations helps leverage our generic fusion // logic. -class BatchNormRewriter : public HloPassInterface { +class BatchNormExpander : public HloPassInterface { public: // When use_fusion is set, a multi-output fusion node is created. - BatchNormRewriter(bool rewrite_training_op = false, + BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, bool rewrite_grad_op = false, bool use_fusion = true) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op), use_fusion_(use_fusion) {} - ~BatchNormRewriter() = default; - tensorflow::StringPiece name() const override { return "batchnorm_rewriter"; } + ~BatchNormExpander() = default; + tensorflow::StringPiece name() const override { return "batchnorm_expander"; } // Run operation expander on the given computation. Returns whether the // computation was changed. @@ -52,4 +52,4 @@ class BatchNormRewriter : public HloPassInterface { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc similarity index 93% rename from tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc rename to tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 590f79aee51ccf410823b91fd8ad09fc7c429c7d..aa36e64b07099a372dab67babc7a18a2d39596bc 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include #include @@ -36,10 +36,10 @@ limitations under the License. namespace xla { namespace { -using BatchNormRewriterTest = HloTestBase; +using BatchNormExpanderTest = HloTestBase; // Test that we expand BatchNormTraining. -TEST_F(BatchNormRewriterTest, BatchNormTraining) { +TEST_F(BatchNormExpanderTest, BatchNormTraining) { Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); Shape scale_shape = ShapeUtil::MakeShape(F32, {2}); Shape offset_shape = ShapeUtil::MakeShape(F32, {2}); @@ -63,7 +63,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining); - BatchNormRewriter rewriter(/*rewrite_training_op=*/true, + BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); @@ -73,7 +73,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) { } // Test that we expand BatchNormGrad. -TEST_F(BatchNormRewriterTest, BatchNormGrad) { +TEST_F(BatchNormExpanderTest, BatchNormGrad) { Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); Shape scale_shape = ShapeUtil::MakeShape(F32, {2}); Shape mean_shape = ShapeUtil::MakeShape(F32, {2}); @@ -105,7 +105,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad); - BatchNormRewriter rewriter(/*rewrite_training_op=*/true, + BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 19a9ff04def5fc3d0b3739bbcf546a74114759a6..33fe11b81db1a1db40285d5c77d8900722025d1c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -73,9 +73,10 @@ void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, CHECK_LE(offset, size_) << "LogicalBuffer " << buffer << " offset out of range"; CHECK_LE(offset + size, size_) - << "LogicalBuffer " << buffer << " size out of range"; + << "LogicalBuffer " << buffer + << " size out of range at offset: " << offset << " with size: " << size; CHECK_EQ(buffer.color(), color()) - << "Buffer color " << buffer.color() + << "Buffer color " << buffer.color() << " for buffer " << buffer << " does not match allocation color " << color() << "."; OffsetSize offset_size; offset_size.offset = offset; @@ -581,6 +582,7 @@ Status GatherComputationsByAllocationType( instruction->called_computations()) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: // Call and while must be called from a computation with global // allocations as they may return references to buffers inside the @@ -976,8 +978,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); if (run_whole_module_heap_simulation) { // Run the heap simulation over the whole module. This reduces memory usage, - // since buffers for kCall and kWhile sub-computations are only live for the - // duration of their calling instructions. + // since buffers for kCall, kWhile, and kConditional sub-computations are + // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; SequentialHloOrdering::HloModuleSequence module_sequence; FlatSet all_buffers_to_assign; @@ -1272,7 +1274,8 @@ const LogicalBuffer* AddBufferToColocatedSet( } // namespace // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated -// in the same allocation (currently just supports kWhile and kCall). +// in the same allocation (currently just supports kWhile, kCall, and +// kConditional). void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, @@ -1336,6 +1339,26 @@ void BufferAssigner::BuildColocatedBufferSets( &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); + } else if (opcode == HloOpcode::kConditional) { + const HloInstruction* conditional_hlo = instruction; + ShapeUtil::ForEachSubshape( + conditional_hlo->shape(), + [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector colocated_set; + // Add conditional.result. + AddBufferToColocatedSet(conditional_hlo, index, + points_to_analysis, &colocated_set); + // Add conditional.true_computation.root. + AddBufferToColocatedSet( + conditional_hlo->true_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + // Add conditional.false_computation.root. + AddBufferToColocatedSet( + conditional_hlo->false_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); } } } @@ -1363,14 +1386,15 @@ void BufferAssigner::AssignColocatedBufferSets( } for (const LogicalBuffer* buffer : colocated_buffer_set) { + const int64 buffer_size = assignment->buffer_size_(*buffer); if (allocation == nullptr) { // TODO(b/32491382) Avoid current trivial solution of using new // allocations for each colocated buffer set. When liveness has // module-level scope, we can allow buffers to be shared across // computations (in some cases). - allocation = assignment->NewAllocation( - *buffer, assignment->buffer_size_(*buffer), - /*is_thread_local=*/false, /*is_reusable=*/true); + allocation = assignment->NewAllocation(*buffer, buffer_size, + /*is_thread_local=*/false, + /*is_reusable=*/true); if (entry_parameter_number >= 0) { // This colocated buffer set contains an entry parameter and other // logical buffers which use the parameter as read-only in a while @@ -1381,8 +1405,11 @@ void BufferAssigner::AssignColocatedBufferSets( } colocated_allocations->insert(allocation->index()); } else { + CHECK_EQ(buffer_size, allocation->size()) + << "Buffer: " << *buffer << " size mismatch in colocated buffer " + << "allocation: " << *allocation; assignment->AddAssignment(allocation, *buffer, /*offset=*/0, - assignment->buffer_size_(*buffer)); + buffer_size); } colocated_buffers->insert(buffer); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8fba8ef5e5c799eaac429017f4a0ff6a0315ba7c..6fc9d783f1b34de8c0f93c6aa342591891d08eaf 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -166,6 +166,15 @@ class BufferAssignmentTest : public HloTestBase { return builder.Build(); } + std::unique_ptr BuildR0F32UnaryOpComputation( + HloOpcode opcode, const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param)); + return builder.Build(); + } + // Verifies that the given instruction hlo has a valid input buffer assigned, // i.e., the parameter number matches the op's. const BufferAllocation& GetAssignedInputAllocation( @@ -740,6 +749,56 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { << " instructions; total buffer size " << size0 + sizec + sizeb; } +TEST_F(BufferAssignmentTest, ExampleConditional) { + auto module = CreateNewModule(); + auto true_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); + auto false_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor")); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + auto const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.4f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + r0f32_, pred, const1, true_computation, const2, false_computation)); + module->AddEntryComputation(builder.Build()); + + const std::vector conditional_instrs = + GetInstructions(conditional); + const std::vector true_instrs = + GetInstructions(true_computation->root_instruction()); + const std::vector false_instrs = + GetInstructions(false_computation->root_instruction()); + EXPECT_EQ(4, conditional_instrs.size()); + EXPECT_EQ(2, true_instrs.size()); + EXPECT_EQ(2, false_instrs.size()); + + auto buffers = RunBufferAssignment(module.get()); + ValidateBuffers(conditional_instrs, *buffers); + ValidateBuffers(true_instrs, *buffers); + ValidateBuffers(false_instrs, *buffers); + + EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers)) + << "Should be reuse between conditional and true computation."; + EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers)) + << "Should be reuse between conditional and false computation."; + EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers)) + << "Should be reuse between true and false computations."; + + const BufferAllocation& conditional_buffer = + GetTopLevelAllocation(*buffers, conditional); + const BufferAllocation& true_buffer = + GetTopLevelAllocation(*buffers, true_computation->root_instruction()); + const BufferAllocation& false_buffer = + GetTopLevelAllocation(*buffers, false_computation->root_instruction()); + EXPECT_EQ(conditional_buffer.size(), true_buffer.size()); + EXPECT_EQ(conditional_buffer.size(), false_buffer.size()); +} + TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); @@ -1360,10 +1419,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateParameter(1, shape_3x4, "param_b")); auto param_c = builder.AddInstruction( HloInstruction::CreateParameter(2, shape_4x4, "param_c")); - auto dot_ab = builder.AddInstruction(HloInstruction::CreateBinary( - shape_2x4, HloOpcode::kDot, param_a, param_b)); - auto dot_bc = builder.AddInstruction(HloInstruction::CreateBinary( - shape_3x4, HloOpcode::kDot, param_b, param_c)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_ab = builder.AddInstruction( + HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); + auto dot_bc = builder.AddInstruction( + HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); builder.AddInstruction( HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); @@ -1708,9 +1770,8 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { BufferAssigner::Run( module.get(), xla::MakeUnique(module.get(), sequence), - ByteSizeOf, - [](LogicalBuffer::Color) { return 1; }) - .ConsumeValueOrDie(); + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) + .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 513bfa3b7f7b45696093d03c1dd8250c548d260a..e7749252ce44f0daf7016f72d80401695eaaacb9 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -102,8 +102,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, return false; } - // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { + // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (auto user : alias.instruction()->users()) { if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user, points_to_analysis())) { @@ -114,6 +114,16 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, return false; } } + + // If the root instruction aliases the buffer 'a', the live range of 'a' is + // until the end of the computation and can never be strictly before another + // buffer. This is needed to prevent the root instruction's buffers from + // being reused by later instructions even when the root is not the last + // instruction in the schedule. + if (alias.instruction()->parent()->root_instruction() == + alias.instruction()) { + return false; + } } // If 'b' is a user of 'a' then the buffers interfere unless 'a.instruction' diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index bbb42d494b8003176d4911bacbe8a10dc5fc7c6a..f623aef67a4f98b447a9a15634a78deb60cfe6f1 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -167,11 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run( - module.get(), - xla::MakeUnique( - module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = + BufferLiveness::Run(module.get(), xla::MakeUnique( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -296,7 +295,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { module_sequence.emplace(computation, order); auto liveness = BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) + module.get(), module_sequence)) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -312,6 +311,48 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp)); } +TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { + // Tests that when the root instruction is not the last instruction in the + // schedule, the live range of its buffers interfere with the buffers of the + // later instructions. + // + // Two sets of independent instructions are executed in the computation. + // param --> add (root) + // recv --> recv-done --> send --> send-done + // + // Sequential order: + // param, add (root), recv, recv-done, send, send-done + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param)); + auto recv = builder.AddInstruction( + HloInstruction::CreateRecv(vec_, /*channel_id=*/0)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + auto send = builder.AddInstruction( + HloInstruction::CreateSend(recv_done, /*channel_id=*/1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build(add)); + + SequentialHloOrdering::HloModuleSequence module_sequence; + std::vector order = {param, add, recv, + recv_done, send, send_done}; + module_sequence.emplace(computation, order); + auto liveness = + BufferLiveness::Run(module.get(), xla::MakeUnique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); + + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); + // Check the root instruction (add) buffer interferes with the recv buffer. + EXPECT_TRUE( + liveness->MayInterfere(GetBuffer(*liveness, add, /*index=*/{}), + GetBuffer(*liveness, recv, /*index=*/{0}))); +} + TEST_F(BufferLivenessTest, TupleLiveOut) { // Verify MaybeLiveOut with nested tuples. Result of computation looks like: // @@ -625,9 +666,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique( - module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. @@ -738,9 +778,8 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique( - module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 1adecdb939cb2c1259003d3be2c90b5a299b0f30..13eb02ca012f44b2b5ed7c6f5becb7d54b07c33c 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -54,6 +54,7 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { CallContext GetInstructionCallContext(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; case HloOpcode::kMap: diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 0395ea8c8b52315f7ca2221f412750ebadda2dd8..1ea7d538cd515c3098b6a1f03c6146d288330406 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -34,12 +34,13 @@ using ::testing::UnorderedElementsAre; class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. - std::unique_ptr MakeScalarComputation() { + std::unique_ptr MakeScalarComputation( + HloOpcode opcode = HloOpcode::kNegate) { HloComputation::Builder builder(TestName() + ".ScalarComputation"); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); builder.AddInstruction( - HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0)); + HloInstruction::CreateUnary(kScalarShape, opcode, param0)); return builder.Build(); } @@ -236,6 +237,54 @@ TEST_F(CallGraphTest, ContextBothComputations) { EXPECT_EQ(CallContext::kBoth, sub_node.context()); } +TEST_F(CallGraphTest, ComputationWithConditional) { + // Test a call graph of a module with a conditional. + auto module = CreateNewModule(); + HloComputation* true_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); + HloComputation* false_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor)); + + HloComputation::Builder builder(TestName()); + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction* const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction* const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.6f))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, const1, true_computation, const2, + false_computation)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + std::unique_ptr call_graph = CallGraph::Build(module.get()); + + EXPECT_EQ(3, call_graph->nodes().size()); + + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(1, entry_node.callsites().size()); + + const CallSite& conditional_callsite = entry_node.callsites()[0]; + EXPECT_EQ(conditional, conditional_callsite.instruction()); + EXPECT_THAT(conditional_callsite.called_computations(), + UnorderedElementsAre(true_computation, false_computation)); + EXPECT_EQ(CallContext::kSequential, conditional_callsite.context()); + EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite); + + const CallGraphNode& true_node = call_graph->GetNode(true_computation); + EXPECT_TRUE(true_node.callees().empty()); + EXPECT_EQ(1, true_node.callers().size()); + EXPECT_EQ(entry_computation, true_node.callers()[0]); + + const CallGraphNode& false_node = call_graph->GetNode(false_computation); + EXPECT_TRUE(false_node.callees().empty()); + EXPECT_EQ(1, false_node.callers().size()); + EXPECT_EQ(entry_computation, false_node.callers()[0]); +} + TEST_F(CallGraphTest, ComplexGraph) { // Test a call graph of a module with several computation called in various // contexts. The call graph looks like: diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 3aa7f5c4d5829ccc0e8df697c1363754128ff436..482ccc5b67109258f544e5657ecfa0e8f62192c0 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -82,6 +82,10 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { return outer_->ReplaceInstruction(call_, new_root); } + CallInliner::InlinedInstructionMap ConsumeInstructionMap() { + return std::move(subcomputation_hlo_to_new_hlo_); + } + private: // Resolves the callee subcomputation_hlo to the new (inline) HLO in the // caller computation, or returns a NotFound error if that subcomputation HLO @@ -112,13 +116,13 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { HloInstruction* call_; HloComputation* outer_; - std::unordered_map - subcomputation_hlo_to_new_hlo_; + CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_; }; } // namespace -/* static */ Status CallInliner::Inline(HloInstruction* call) { +/* static */ StatusOr CallInliner::Inline( + HloInstruction* call) { TF_RET_CHECK(call->opcode() == HloOpcode::kCall) << "Instruction was not a call op: " << call->opcode(); const auto& callees = call->called_computations(); @@ -126,7 +130,8 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { HloComputation* callee = callees[0]; // We visit the callee, cloning its body into its caller. SubcomputationInsertionVisitor visitor(call); - return callee->Accept(&visitor); + TF_RETURN_IF_ERROR(callee->Accept(&visitor)); + return visitor.ConsumeInstructionMap(); } StatusOr CallInliner::Run(HloModule* module) { @@ -140,7 +145,7 @@ StatusOr CallInliner::Run(HloModule* module) { VLOG(1) << "Visiting callsite: " << callsite.ToString(); if (callsite.instruction()->opcode() == HloOpcode::kCall) { HloInstruction* call = callsite.instruction(); - TF_RETURN_IF_ERROR(Inline(call)); + TF_RETURN_IF_ERROR(Inline(call).status()); did_mutate = true; } } diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index 2dbd38bf1ac90d3efa1453e6af6f791668d5e72a..a8345a394d46c90a48305313dac0bcd9b06938ac 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -27,8 +27,12 @@ namespace xla { // called function, and proceed recursively. class CallInliner : public HloPassInterface { public: - // Inlines one call instruction. - static Status Inline(HloInstruction* call); + using InlinedInstructionMap = + std::unordered_map; + + // Inlines one call instruction. Returns a mapping from the original + // instructions to their inlined versions. + static StatusOr Inline(HloInstruction* call); ~CallInliner() override = default; tensorflow::StringPiece name() const override { return "CallInliner"; } diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 865ed993da121d26ceb61123f1822d93814cbb9b..738d00881dd057fc13c115006c15e8f5b6d14a1d 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -135,7 +135,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { HloInstruction::CreateCall(pred, {}, false_computation)); auto computation = module->AddEntryComputation(call_false_builder.Build()); - TF_ASSERT_OK(CallInliner::Inline(call)); + TF_ASSERT_OK(CallInliner::Inline(call).status()); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_THAT(computation->root_instruction()->control_successors(), ElementsAre(op::Constant())); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 3278fd5f064902459ded4d9367b5390cf8a63f27..128ee726ea6e4a8b63727fdc9762d865cee1c985 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -339,7 +340,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - // The return value of the computation is the zero-th elemnt of the nested + // The return value of the computation is the zero-th element of the nested // tuple. This element is itself a tuple. auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); @@ -1726,5 +1727,189 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); +TEST_F(CopyInsertionTest, SimpleControlFlowTest) { + const string& hlo_string = R"( +HloModule TestModule + +if-body.v5 { + constant.3 = s32[] constant(-1) + p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 + get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 + get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 + add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) + tuple.33 = (s32[]) tuple(add.3) + ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) +} + +if-condition.v4 { + p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 + constant.4 = s32[] constant(0) + ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) +} + +_functionalize_body_1__.v28 { + arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0 + constant.7 = s32[] constant(1) + add.4 = s32[] add(get-tuple-element.68, constant.7) + get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1 + get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2 + less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70) + constant.8 = s32[] constant(0) + select = s32[] select(less-than-or-equal-to, constant.8, constant.7) + get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3 + tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70) + tuple.36 = (s32[]) tuple(constant.8) + tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36) + while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5 + get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2 + get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0 + ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73) +} + +cond_wrapper.v3.1 { + inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0 + constant.11 = s32[] constant(7) + ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11) +} + +_functionalize_body_2__.v25 { + arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0 + get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2 + get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3 + get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4 + tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79) + while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 + get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0 + get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1 + constant.12 = s32[] constant(1) + add.5 = s32[] add(get-tuple-element.81, constant.12) + get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3 + ROOT tuple.40 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82) +} + +cond_wrapper.v3.2 { + inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1 + constant.13 = s32[] constant(5) + ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13) +} + +ENTRY TestComputation { + arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 +} +)"; + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module = module_or_status.ConsumeValueOrDie(); + InsertCopies(module.get()); +} + +TEST_F(CopyInsertionTest, ControlFlowTest) { + const string& hlo_string = R"( +HloModule TestModule + +if-body.v5 { + constant.3 = s32[] constant(-1) + p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 + get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 + get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 + add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) + tuple.33 = (s32[]) tuple(add.3) + ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) +} + +if-condition.v4 { + p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 + constant.4 = s32[] constant(0) + ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) +} + +if-body.v5.1 { + constant.5 = s32[] constant(-1) + p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1 + get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2 + multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70) + tuple.35 = (s32[]) tuple(multiply.1) + ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35) +} + +if-condition.v4.1 { + p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0 + constant.6 = s32[] constant(1) + ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6) +} + +_functionalize_body_1__.v28 { + arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0 + constant.7 = s32[] constant(1) + add.4 = s32[] add(get-tuple-element.72, constant.7) + get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1 + get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2 + less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74) + constant.8 = s32[] constant(0) + select = s32[] select(less-than-or-equal-to, constant.8, constant.7) + get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3 + tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74) + tuple.38 = (s32[]) tuple(constant.8) + tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38) + while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5 + while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1 + get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2 + get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0 + ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77) +} + +cond_wrapper.v3.1 { + inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.78 = s32[] get-tuple-element(inputs.1), index=0 + constant.11 = s32[] constant(7) + ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11) +} + +_functionalize_body_2__.v25 { + arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0 + get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2 + get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3 + get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4 + tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82) + while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 + get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0 + get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1 + constant.12 = s32[] constant(1) + add.5 = s32[] add(get-tuple-element.84, constant.12) + get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3 + ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85) +} + +cond_wrapper.v3.2 { + inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1 + constant.13 = s32[] constant(5) + ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13) +} + +ENTRY TestComputation { + arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 +} +)"; + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module = module_or_status.ConsumeValueOrDie(); + InsertCopies(module.get()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index e1eed498f6adfdae9df1dbf183f7c0505afd4ea2..2f0259163120dd5d62a5d1289deada8dc59c2c6c 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -81,14 +81,15 @@ cc_library( ":conv_canonicalization", ":cpu_copy_insertion", ":cpu_executable", + ":cpu_hlo_support_checker", ":cpu_instruction_fusion", + ":cpu_layout_assignment", ":cpu_options", ":cpu_parallelization_preparation", ":disassembler", ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":parallel_cpu_executable", ":parallel_task_assignment", ":simple_orc_jit", @@ -100,16 +101,18 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", - "//tensorflow/compiler/xla/service:batchnorm_rewriter", + "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", @@ -124,7 +127,9 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", + "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep "//tensorflow/core:lib", # fixdeps: keep "//tensorflow/core:stream_executor_no_cuda", @@ -135,8 +140,6 @@ cc_library( "@llvm//:core", "@llvm//:mc", # fixdeps: keep "@llvm//:object", - "@llvm//:powerpc_code_gen", # fixdeps: keep - "@llvm//:powerpc_disassembler", # fixdeps: keep "@llvm//:support", "@llvm//:target", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep @@ -147,7 +150,11 @@ cc_library( cc_library( name = "simple_orc_jit", - srcs = ["simple_orc_jit.cc"], + srcs = [ + "simple_orc_jit.cc", + "windows_compatibility.cc", + "windows_compatibility.h", + ], hdrs = ["simple_orc_jit.h"], deps = [ ":compiler_functor", @@ -160,6 +167,7 @@ cc_library( ":external_constant_pool", ":orc_jit_memory_mapper", ":runtime_conv2d", + ":runtime_fft", ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", @@ -250,8 +258,11 @@ cc_library( ":dot_op_emitter", ":external_constant_pool", ":ir_emission_utils", + ":ir_function", + ":parallel_loop_emitter", ":shape_partition", ":simple_orc_jit", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -280,6 +291,54 @@ cc_library( ], ) +cc_library( + name = "target_machine_features", + srcs = [ + "target_machine_features.cc", + ], + hdrs = ["target_machine_features.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + "@llvm//:analysis", + "@llvm//:target", + ], +) + +cc_library( + name = "ir_function", + srcs = ["ir_function.cc"], + hdrs = ["ir_function.h"], + deps = [ + ":ir_emission_utils", + ":shape_partition", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "parallel_loop_emitter", + srcs = ["parallel_loop_emitter.cc"], + hdrs = ["parallel_loop_emitter.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], @@ -287,6 +346,8 @@ cc_library( deps = [ ":cpu_options", ":cpu_runtime", + ":target_machine_features", + ":vector_support_library", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", @@ -298,7 +359,6 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/compiler/xla/service/llvm_ir:vector_support_library", "//tensorflow/core:lib", "@llvm//:core", ], @@ -336,7 +396,6 @@ cc_library( "@llvm//:mc", "@llvm//:mc_disassembler", "@llvm//:object", - "@llvm//:powerpc_disassembler", # fixdeps: keep "@llvm//:support", "@llvm//:target", "@llvm//:x86_disassembler", # fixdeps: keep @@ -462,6 +521,24 @@ cc_library( ], ) +cc_library( + name = "runtime_fft", + srcs = [ + "runtime_fft.cc", + "runtime_fft_impl.h", + ], + hdrs = ["runtime_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_matvec", srcs = ["runtime_matvec.cc"], @@ -615,13 +692,14 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", + "@llvm//:core", ], ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "cpu_layout_assignment", + srcs = ["cpu_layout_assignment.cc"], + hdrs = ["cpu_layout_assignment.h"], deps = [ ":dot_op_emitter", ":ir_emission_utils", @@ -633,11 +711,11 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", + name = "cpu_layout_assignment_test", size = "small", - srcs = ["layout_assignment_test.cc"], + srcs = ["cpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":cpu_layout_assignment", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -763,6 +841,20 @@ cc_library( ], ) +cc_library( + name = "vector_support_library", + srcs = ["vector_support_library.cc"], + hdrs = ["vector_support_library.h"], + deps = [ + ":target_machine_features", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "@llvm//:core", + ], +) + tf_cc_test( name = "cpu_copy_insertion_test", srcs = ["cpu_copy_insertion_test.cc"], @@ -783,6 +875,32 @@ tf_cc_test( ], ) +cc_library( + name = "cpu_hlo_support_checker", + srcs = ["cpu_hlo_support_checker.cc"], + hdrs = ["cpu_hlo_support_checker.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "cpu_hlo_support_checker_test", + srcs = ["cpu_hlo_support_checker_test.cc"], + deps = [ + ":cpu_hlo_support_checker", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index addd7284c593f3dcdd86b1745f9aef7b6a1c30c6..f0507982b3749b179dbd7d76c46d39a209640661 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "llvm/Object/ObjectFile.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/TargetRegistry.h" @@ -42,7 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" @@ -50,24 +51,27 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -83,7 +87,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -149,11 +155,6 @@ CpuCompiler::CpuCompiler() { LLVMInitializeAArch64TargetMC(); LLVMInitializeAArch64AsmPrinter(); LLVMInitializeAArch64Disassembler(); - LLVMInitializePowerPCTarget(); - LLVMInitializePowerPCTargetInfo(); - LLVMInitializePowerPCTargetMC(); - LLVMInitializePowerPCAsmPrinter(); - LLVMInitializePowerPCDisassembler(); } namespace { @@ -166,42 +167,16 @@ namespace { // first module is compiled. std::once_flag llvm_command_line_options_initialized; -void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { - auto options = config.debug_options().xla_backend_extra_options(); - if (!options.empty()) { - std::vector fake_argv_storage; - fake_argv_storage.push_back(""); - for (const auto& it : options) { - // Skip options the XLA backend itself consumes. - if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { - if (it.second.empty()) { - fake_argv_storage.push_back(it.first); - } else { - fake_argv_storage.push_back(it.first + "=" + it.second); - } - } - } - - VLOG(2) << "Passing argv to LLVM:"; - std::vector fake_argv; - for (const auto& s : fake_argv_storage) { - fake_argv.push_back(s.c_str()); - VLOG(2) << s; - } - llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); - } -} - // This visitor records which HLO instructions should have profiling information // recorded. class CollectProfileCandidates : public DfsHloVisitorWithDefault { public: - static StatusOr> + static StatusOr> GetCandidatesForComputation( HloComputation* computation, const std::unordered_map& assigned_indices) { - std::unordered_map hlo_to_profile_idx; + std::unordered_map hlo_to_profile_idx; CollectProfileCandidates profile_candidates_for_computation( &hlo_to_profile_idx, assigned_indices); TF_RETURN_IF_ERROR( @@ -211,7 +186,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { private: CollectProfileCandidates( - std::unordered_map* hlo_to_profile_idx, + std::unordered_map* hlo_to_profile_idx, const std::unordered_map& assigned_indices) : hlo_to_profile_idx_(hlo_to_profile_idx), assigned_indices_(assigned_indices) {} @@ -251,7 +226,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { return Status::OK(); } - std::unordered_map* hlo_to_profile_idx_; + std::unordered_map* hlo_to_profile_idx_; const std::unordered_map& assigned_indices_; }; } // namespace @@ -259,7 +234,8 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Optimization pipeline. HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker(ShapeSizeBytesFunction()); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), @@ -272,14 +248,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); - + pipeline.AddPass(); pipeline.AddPass(); { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(ShapeSizeBytesFunction()); + pass.AddInvariantChecker(); - pass.AddPass( + pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, @@ -288,6 +264,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, /*enable_dot_strength_reduction=*/false); + + // BatchNormExpander can create zero-sized ops, so zero-sized HLO + // elimination has to come after that pass. + pipeline.AddPass(); + + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -318,6 +300,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { [](const Shape&, const Shape&) { return true; }, /*enable_dot_strength_reduction=*/false); pipeline.AddPass(/*is_layout_sensitive=*/true); + pipeline.AddPass(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -435,6 +418,21 @@ Status InitializeModuleHooks( return Status::OK(); } +Status VerifyLlvmModule(const llvm::Module& llvm_module) { + XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. + TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " + "Rerun with --xla_dump_ir_to to get the IR. "; + return Status::OK(); +} + } // namespace StatusOr> CpuCompiler::RunHloPasses( @@ -460,7 +458,7 @@ StatusOr> CpuCompiler::RunBackend( VLOG(1) << "Compiling: " << module->name(); TF_RET_CHECK(stream_exec != nullptr); std::call_once(llvm_command_line_options_initialized, - &InitializeLLVMCommandLineOptions, module->config()); + &llvm_ir::InitializeLLVMCommandLineOptions, module->config()); ModuleHook pre_optimization_ir_hook; ModuleHook post_optimization_ir_hook; @@ -483,17 +481,19 @@ StatusOr> CpuCompiler::RunBackend( llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - HloComputation* computation = module->entry_computation(); - std::unordered_map hlo_to_profile_idx; + HloComputation* entry_computation = module->entry_computation(); + std::unordered_map instruction_to_profile_idx; + std::unordered_map computation_to_profile_idx; std::unique_ptr hlo_profile_index_map; std::unique_ptr hlo_profile_printer; if (module->config().hlo_profiling_enabled()) { hlo_profile_index_map = MakeUnique(*module); TF_ASSIGN_OR_RETURN( - hlo_to_profile_idx, + instruction_to_profile_idx, CollectProfileCandidates::GetCandidatesForComputation( - computation, hlo_profile_index_map->instruction_to_profile_idx())); + entry_computation, + hlo_profile_index_map->instruction_to_profile_idx())); auto shape_size_bytes = [](const Shape& shape) { // On the cpu, opaques are pointers. @@ -504,8 +504,11 @@ StatusOr> CpuCompiler::RunBackend( }; HloCostAnalysis cost_analysis(shape_size_bytes); + TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); hlo_profile_printer = CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis); + computation_to_profile_idx = + hlo_profile_index_map->computation_to_profile_idx(); } std::unique_ptr cpu_executable; @@ -528,9 +531,9 @@ StatusOr> CpuCompiler::RunBackend( // uses data dependencies for determining order. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(module.get(), - xla::MakeUnique(module.get()), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run( + module.get(), xla::MakeUnique(module.get()), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -546,7 +549,7 @@ StatusOr> CpuCompiler::RunBackend( std::map parallel_computations; std::unordered_map> aligned_constants; - for (auto instruction : computation->MakeInstructionPostOrder()) { + for (auto instruction : entry_computation->MakeInstructionPostOrder()) { // Parameters and constants don't get their own computation. if (instruction->opcode() == HloOpcode::kParameter) { continue; @@ -554,7 +557,7 @@ StatusOr> CpuCompiler::RunBackend( if (instruction->opcode() == HloOpcode::kConstant) { // Copy the constant out of the ProtocolBuffer so that we can give it a // higher alignment. - const void* data = instruction->literal().InternalData(); + const void* data = instruction->literal().untyped_data(); int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( instruction, xla::MakeUnique(size)); @@ -571,22 +574,15 @@ StatusOr> CpuCompiler::RunBackend( parallel_computations.emplace(to_apply, instruction); } - // We always profile the entire computation as a whole, even if hlo - // profiling is disabled. When hlo profiling is diabled, we pass in a - // profile counter array of just one element, which corresponds to the whole - // computation. - size_t entry_computation_profile_idx = - hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor( - *module->entry_computation()) - : 0; IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - hlo_to_profile_idx, entry_computation_profile_idx, + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), jit->target_machine(), jit->external_constant_pool()); std::unique_ptr> function_names( new HloInstructionMap()); for (auto embedded_computation : - computation->MakeEmbeddedComputationsList()) { + entry_computation->MakeEmbeddedComputationsList()) { if (embedded_computation->IsFusionComputation()) { continue; } @@ -600,7 +596,7 @@ StatusOr> CpuCompiler::RunBackend( llvm::Function * ir_function, ir_emitter.EmitComputation( embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/computation_is_parallel, + /*is_top_level_computation=*/computation_is_parallel, /*instruction_order=*/nullptr)); // If this computation is parallel, remember it in the function name map. // This way we know what function to execute when we try to run code for @@ -616,6 +612,7 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); @@ -642,10 +639,10 @@ StatusOr> CpuCompiler::RunBackend( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run(module.get(), + xla::MakeUnique( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -655,14 +652,6 @@ StatusOr> CpuCompiler::RunBackend( TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( proto, xla_dump_hlo_proto_to, module->name())); } - // We always profile the entire computation as a whole, even if hlo - // profiling is disabled. When hlo profiling is diabled, we pass in a - // profile counter array of just one element, which corresponds to the whole - // computation. - size_t entry_computation_profile_idx = - hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor( - *module->entry_computation()) - : 0; // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from @@ -670,11 +659,12 @@ StatusOr> CpuCompiler::RunBackend( // before a caller computation. IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - hlo_to_profile_idx, entry_computation_profile_idx, + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), jit->target_machine(), jit->external_constant_pool()); for (auto embedded_computation : - computation->MakeEmbeddedComputationsList()) { + entry_computation->MakeEmbeddedComputationsList()) { if (embedded_computation->IsFusionComputation()) { continue; } @@ -682,23 +672,27 @@ StatusOr> CpuCompiler::RunBackend( ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } - string function_name_prefix = - computation->name().empty() ? "__compute" : computation->name(); + string function_name_prefix = entry_computation->name().empty() + ? "__compute" + : entry_computation->name(); TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, - ir_emitter.EmitComputation(computation, function_name_prefix, - /*is_entry_computation=*/true, - &module_sequence.at(computation))); + ir_emitter.EmitComputation(entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &module_sequence.at(entry_computation))); string function_name = llvm_ir::AsString(entry_function->getName()); string ir_module_string; if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); + + XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); @@ -721,7 +715,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& aot_options) { TF_RET_CHECK(!modules.empty()); std::call_once(llvm_command_line_options_initialized, - &InitializeLLVMCommandLineOptions, modules[0]->config()); + &llvm_ir::InitializeLLVMCommandLineOptions, + modules[0]->config()); // We can pass just one llvm::TargetOptions when we compile the LLVM module, // so we bail if the configs have conflicting flags. At the moment, the only @@ -824,7 +819,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run( - module, xla::MakeUnique(module, module_sequence), + module, + xla::MakeUnique(module, module_sequence), BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. @@ -838,13 +834,13 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, proto, xla_dump_hlo_proto_to, module->name())); } - IrEmitter ir_emitter( - *module, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/ - std::unordered_map{}, - /*entry_computation_profile_idx=*/tensorflow::gtl::nullopt, - target_machine.get(), - /*external_constant_pool=*/nullptr); + IrEmitter ir_emitter(*module, *assignment, &llvm_module, + /*instruction_to_profile_idx=*/ + std::unordered_map{}, + /*computation_to_profile_idx=*/ + std::unordered_map{}, + target_machine.get(), + /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -855,7 +851,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } @@ -863,7 +859,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, ir_emitter.EmitComputation(computation, entry_point_name, - /*is_entry_computation=*/true, + /*is_top_level_computation=*/true, &module_sequence.at(computation))); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); @@ -874,6 +870,16 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, *module, user_pre_optimization_hook_, user_post_optimization_hook_, &pre_optimization_ir_dump_hook, &post_optimization_ir_dump_hook)); + // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run the + // pre-optimization IR dump hook before returning. + { + Status verify_status = VerifyLlvmModule(llvm_module); + if (!verify_status.ok() && pre_optimization_ir_dump_hook) { + pre_optimization_ir_dump_hook(llvm_module).IgnoreError(); + } + TF_RETURN_IF_ERROR(verify_status); + } + Disassembler disassembler(*target_machine); CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index e956f478b86d9816615e2902f5bbeae6d6384162..f335bd1bbc7376d1cccc0fa6aa1c0a6d6ad559ab 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -73,28 +73,6 @@ CpuExecutable::CpuExecutable( reinterpret_cast(cantFail(sym.getAddress())); } -// Given a pointer to an output buffer (following the CPU JIT calling -// conventions), mark addresses that are "live". The initial pointer itself is -// trivially live. If the shape of the buffer is a tuple, this analysis looks -// into the tuple's elements and marks them live as well (since tuples keep -// pointers to buffers) and also works recursively. address is an in-memory -// buffer address that contains some runtime XLA object. shape is its -// shape. marked_addresses is the set of live addresses to populate. -static void MarkLiveAddressesInOutput( - const void* address, const Shape& shape, - std::unordered_set* marked_addresses) { - marked_addresses->insert(address); - const uintptr_t* address_buffer = static_cast(address); - if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const uintptr_t* element_address = address_buffer + i; - const void* element = reinterpret_cast(*element_address); - MarkLiveAddressesInOutput( - element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses); - } - } -} - Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, std::vector* buffers) { @@ -148,20 +126,6 @@ Status CpuExecutable::ExecuteComputeFunction( tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { - std::vector argument_buffers; - argument_buffers.reserve(arguments.size()); - for (const auto* argument : arguments) { - argument_buffers.push_back(argument->buffer(/*index=*/{})); - } - return ExecuteComputeFunction(run_options, argument_buffers, buffers, - hlo_execution_profile); -} - -Status CpuExecutable::ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: // // void function(void* result, const void* run_options, void** args_array, @@ -177,23 +141,19 @@ Status CpuExecutable::ExecuteComputeFunction( // determined by buffer analysis. // std::vector args_array; - for (se::DeviceMemoryBase arg_mem : arguments) { - args_array.push_back(arg_mem.opaque()); + for (const ShapedBuffer* argument : arguments) { + args_array.push_back(argument->root_buffer().opaque()); } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // Allocate profiling counters for each hlo instruction that we would like to - // profile. Even when not Hlo profiling, we allocate a counter for the entire - // computation, which we use to update ExecutionProfile below. - std::vector* profile_counters = nullptr; - std::vector profile_counter_for_entry_computation; - if (hlo_execution_profile) { - profile_counters = hlo_execution_profile->mutable_profile_counters(); - } else { - profile_counters = &profile_counter_for_entry_computation; - profile_counter_for_entry_computation.push_back(0); - } + size_t profile_counters_size = + hlo_execution_profile ? hlo_execution_profile->profile_counters().size() + : 0; + int64* profile_counters = + hlo_execution_profile + ? hlo_execution_profile->mutable_profile_counters()->data() + : nullptr; // Call the computation function following the calling convention. std::vector buffer_pointers; @@ -208,7 +168,7 @@ Status CpuExecutable::ExecuteComputeFunction( VLOG(3) << tensorflow::strings::Printf( " func(void* result, void* params[%zu], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters->size()); + args_array.size(), buffer_pointers.size(), profile_counters_size); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); @@ -220,11 +180,11 @@ Status CpuExecutable::ExecuteComputeFunction( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters->data()); + profile_counters); } compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters->data()); + buffer_pointers.data(), profile_counters); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -232,13 +192,11 @@ Status CpuExecutable::ExecuteComputeFunction( tensorflow::mutex_lock lock(mutex_); const double nanoseconds = (end_micros - start_micros) * 1000.0; execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - + // If hlo profiling was disabled then the cycle count is left empty. if (hlo_execution_profile) { execution_profile_.set_compute_cycle_count( hlo_execution_profile->total_cycles_executed( *module().entry_computation())); - } else { - execution_profile_.set_compute_cycle_count(profile_counters->back()); } } @@ -246,11 +204,23 @@ Status CpuExecutable::ExecuteComputeFunction( } static void LogLiveAddresses( - const std::unordered_set& marked_addresses) { + tensorflow::gtl::ArraySlice buffers, + const std::vector& buffers_in_result) { + if (!VLOG_IS_ON(3)) { + return; + } + + CHECK_EQ(buffers.size(), buffers_in_result.size()); + std::vector live_out_buffers; + for (int i = 0; i < buffers.size(); ++i) { + if (buffers_in_result[i]) { + live_out_buffers.push_back(buffers[i].opaque()); + } + } VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" + << live_out_buffers.size() << " addresses:\n" << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { + live_out_buffers, ", ", [](string* out, const void* address) { tensorflow::strings::StrAppend( out, tensorflow::strings::Printf("%p", address)); }); @@ -259,13 +229,12 @@ static void LogLiveAddresses( static Status DeallocateTempBuffers( DeviceMemoryAllocator* allocator, se::Stream* stream, tensorflow::gtl::ArraySlice buffers, - const std::unordered_set& marked_addresses) { - // Keep those marked live because they are referenced by the output of the - // computation and are needed by the service. They will be deallocated by the - // service. + const std::vector& buffers_in_result) { + // Keep those buffers in the output of the marked live because they are needed + // by the service. They will be deallocated by the service. for (size_t i = 0; i < buffers.size(); ++i) { se::DeviceMemoryBase alloc = buffers[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) { + if (!buffers_in_result[i] && !alloc.is_null()) { VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" << alloc.opaque() << "]"; TF_RETURN_IF_ERROR( @@ -276,33 +245,43 @@ static Status DeallocateTempBuffers( return Status::OK(); } -StatusOr CpuExecutable::ExecuteOnStream( +StatusOr> CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { + tensorflow::gtl::ArraySlice + allocated_buffers, + std::vector* buffers_in_result) { se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); + auto result_buffer = MakeUnique( + /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), + stream->parent()->platform(), stream->parent()->device_ordinal()); - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. - std::unordered_set marked_addresses; - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; - MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), - &marked_addresses); - - LogLiveAddresses(marked_addresses); - TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, - marked_addresses)); - - return top_level_output; + // Copy DeviceMemoryBase values which contain the array(s) of the result into + // the respective location in ShapedBuffer which is returned to the caller. + TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer such as + // a tuple element. The source instruction should have a + // non-parameter buffer assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *device_memory = buffer; + (*buffers_in_result)[buffer_index] = true; + return Status::OK(); + })); + return std::move(result_buffer); } StatusOr> CpuExecutable::ExecuteOnStream( @@ -317,67 +296,26 @@ StatusOr> CpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - auto result_buffer = - MakeUnique(result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal()); - TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); TF_RETURN_IF_ERROR(ExecuteComputeFunction( &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffers, &buffers_in_result, &result_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *buffer_entry = result_buffer->mutable_buffers()->size(); - result_buffer->mutable_buffers()->push_back(buffer); - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_buffer, + CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); // Free all buffers not in the result. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); - } - } + TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, + buffers_in_result)); return std::move(result_buffer); } -StatusOr -CpuExecutable::ExecuteAsyncOnStream( +StatusOr> CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { if (hlo_profiling_enabled()) { return Unimplemented( "Asynchronous execution on stream with hlo profiling is not yet " @@ -393,29 +331,25 @@ CpuExecutable::ExecuteAsyncOnStream( TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. - std::unordered_set marked_addresses; - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; - MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), - &marked_addresses); + std::vector buffers_in_result(assignment_->Allocations().size(), false); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_buffer, + CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - LogLiveAddresses(marked_addresses); + LogLiveAddresses(buffers, buffers_in_result); host_stream->EnqueueTask([this, run_options, arguments, buffers, - marked_addresses, memory_allocator, stream]() { + buffers_in_result, memory_allocator, stream]() { // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, buffers, /*hlo_execution_profile=*/nullptr)); TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, - marked_addresses)); + buffers_in_result)); }); - return top_level_output; + return std::move(result_buffer); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 17ee2d673ee7cde1847bf29e2399e6033cb7e30e..50443a59954e222f65fc935e83effdaf6d6c8bf0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -55,21 +55,14 @@ class CpuExecutable : public Executable { std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -108,13 +101,6 @@ class CpuExecutable : public Executable { // Calls the generated function performing the computation with the given // arguments using the supplied buffers. - Status ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -122,6 +108,18 @@ class CpuExecutable : public Executable { buffers, HloExecutionProfile* hlo_execution_profile); + // Create a ShapedBuffer for holding the result of the computation. The + // addresses (DeviceMemoryBases) are set according to buffer assignment. + // 'buffers_in_result' should point to a vector of the same size as + // 'allocated_buffers'. An element in buffers_in_result is set to true if the + // corresponding buffer is live out of the computation (and thus contained in + // the returned ShapedBuffer). + StatusOr> CreateResultShapedBuffer( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + allocated_buffers, + std::vector* buffers_in_result); + // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. const PointsToSet& GetRootPointsToSet() const; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc new file mode 100644 index 0000000000000000000000000000000000000000..7bd4741a04b1135d9780e0cf765b7b33378526e1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr CpuHloSupportChecker::Run(HloModule* module) { + for (auto* computation : module->computations()) { + for (const auto& instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instruction->shape(), + [&instruction](const Shape& subshape, const ShapeIndex&) { + if (LayoutUtil::IsSparseArray(subshape)) { + return xla::Unimplemented( + "CPU backend does not support HLO instruction %s with shape " + "containing a sparse layout: %s", + instruction->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(instruction->shape()) + .c_str()); + } + return Status::OK(); + })); + } + } + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h new file mode 100644 index 0000000000000000000000000000000000000000..2271af7b247c2684d371010361308b4d7bcd6423 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass should run early in the HLO pipeline and checks for HLO constructs +// which are not supported by the CPU backend and cannot be removed via HLO +// transformations (eg, sparse layouts). +class CpuHloSupportChecker : public HloPassInterface { + public: + CpuHloSupportChecker() = default; + ~CpuHloSupportChecker() override = default; + + tensorflow::StringPiece name() const override { + return "cpu_hlo_support_checker"; + } + + // Note: always returns false (no instructions are ever modified by this + // pass). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f463e6de623fc6ab43d685ff2a5d6882ba7b8a2 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +class CpuHloSupportCheckerTest : public HloTestBase { + protected: + CpuHloSupportChecker& checker() { return checker_; } + + private: + CpuHloSupportChecker checker_; +}; + +TEST_F(CpuHloSupportCheckerTest, Add) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(checker().Run(module.get()).status()); +} + +TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { + HloComputation::Builder builder(TestName()); + const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, sparse_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, sparse_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + sparse_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Status status = checker().Run(module.get()).status(); + ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_THAT(status.error_message(), + HasSubstr("CPU backend does not support")); + EXPECT_THAT(status.error_message(), + HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index f87ee3cecd932faac140636a3db7cd4aa0371b85..482e04052d5a914eab0e5bff2c7a83f3b698052f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -26,7 +26,7 @@ int64 BytesInDimension(const Shape& shape, int64 dimension) { shape.dimensions(dimension); } -bool IsFusile(const HloInstruction& hlo) { +bool CanBeLoopFused(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. return hlo.IsElementwise() || // @@ -42,6 +42,23 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsMatrixVectorDot(const HloInstruction* hlo) { + const Shape& hlo_shape = hlo->shape(); + return hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 && + (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1); +} + +bool CanBeOutputFused(const HloInstruction* producer, + const HloInstruction* consumer) { + return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) && + producer->user_count() == 1; +} + +bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) { + return consumer->opcode() == HloOpcode::kAdd && + (CanBeOutputFused(consumer->operand(0), consumer) || + CanBeOutputFused(consumer->operand(1), consumer)); +} } // namespace bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, @@ -52,7 +69,15 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, constexpr int kFusionThresholdBytes = 16 * 1024; - if (!IsFusile(*producer)) { + if (CanBeOutputFused(producer, consumer)) { + return true; + } + + if (CanBeOutputFusedIntoSomeOperand(producer)) { + return false; + } + + if (!CanBeLoopFused(*producer)) { VLOG(2) << "Producer is not fusile."; return false; } @@ -108,16 +133,13 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } } - if (consumer->opcode() == HloOpcode::kFusion) { - // InstructionFusion::ShouldFuse above only allows kLoop and kInput fusions. - // The CPU backend does not create kInput fusions, so we only expect to see - // kLoop here. - CHECK(consumer->fusion_kind() == HloInstruction::FusionKind::kLoop); + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop) { VLOG(2) << "Fusing: consumer is a fusion node."; return true; } - if (IsFusile(*consumer)) { + if (CanBeLoopFused(*consumer)) { VLOG(2) << "Fusing: consumer is elementwise or fusile."; return true; } @@ -126,5 +148,11 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } +HloInstruction::FusionKind CpuInstructionFusion::ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) { + return CanBeOutputFused(producer, consumer) + ? HloInstruction::FusionKind::kOutput + : HloInstruction::FusionKind::kLoop; +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h index 0eca4c3473e1454fe5dbd8bf855b4418cf553a94..07aff34974e0cfa6c7a129f82017b280fb1ccd59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h @@ -30,6 +30,8 @@ class CpuInstructionFusion : public InstructionFusion { protected: bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + HloInstruction::FusionKind ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) override; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index b9e4d006d77ae76e33ac51440349400ea4eff118..595c3f55b321f47e2312b93e0c238c7637495d77 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -31,6 +31,14 @@ namespace { using InstructionFusionTest = HloTestBase; +std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); +} + TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -40,8 +48,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, exp0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -59,8 +67,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -80,8 +88,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, bitcast0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -102,8 +110,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* reshape0 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1024, 256}), exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, reshape0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -121,8 +129,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 32 * 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -140,8 +148,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -162,8 +170,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, transpose1)); + builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -188,7 +196,9 @@ class OpcodeFusionTest : public InstructionFusionTest { // Runs CPU instruction fusion on the given module, and tests that the result // contains a fused op at the root with exactly the given multiset of opcodes. void RunFusionAndCheckOpcodesWereFused( - HloModule* module, const std::multiset& expected_opcodes) { + HloModule* module, const std::multiset& expected_opcodes, + HloInstruction::FusionKind fusion_kind = + HloInstruction::FusionKind::kLoop) { auto computation = module->entry_computation(); auto did_fusion = CpuInstructionFusion().Run(module); ASSERT_TRUE(did_fusion.ok()); @@ -196,7 +206,7 @@ class OpcodeFusionTest : public InstructionFusionTest { HloInstruction* root = computation->root_instruction(); ASSERT_THAT(root, op::Fusion()); - EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_EQ(root->fusion_kind(), fusion_kind); std::vector fused_opcodes(root->fused_instruction_count()); std::transform(root->fused_instructions().begin(), @@ -608,6 +618,88 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { Not(op::Fusion())); } +void CreateComputationForDotAddOutputFusionTest(const string& test_name, + HloModule* module, int m, int k, + int n, + bool add_extra_use_for_dot) { + HloComputation::Builder builder(test_name); + + Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + + auto* dot_lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); + auto* dot_rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, dot_rhs_shape, "param1")); + auto* addend = builder.AddInstruction( + HloInstruction::CreateParameter(2, dot_shape, "param2")); + + auto* dot = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + builder.AddInstruction( + HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); + + if (add_extra_use_for_dot) { + builder.AddInstruction( + HloInstruction::CreateOutfeed(dot_shape, dot, "no_config")); + } + + module->AddEntryComputation(builder.Build()); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, + /*k=*/50, /*n=*/19, + /*add_extra_use_for_dot=*/false); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}, + HloInstruction::FusionKind::kOutput); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/1, + /*add_extra_use_for_dot=*/false); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}, + HloInstruction::FusionKind::kOutput); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/19, + /*add_extra_use_for_dot=*/false); + + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/1, + /*add_extra_use_for_dot=*/true); + + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc similarity index 55% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index 3f2d101959db50d9f775097f01d5a2ba25a0da8c..e8117377e61a4e21b8c45b929c518a18878fcb60 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include @@ -25,58 +25,77 @@ limitations under the License. namespace xla { namespace cpu { -Status CpuLayoutAssignment::AddBackendConstraints( - LayoutConstraints* constraints) { - auto row_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - auto col_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.begin(), dimension_order.end(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - - // We want to change the layout of constant arrays to be column major when all - // of their users are dot operations that can be made faster with the flipped - // layout. To avoid going quadriatic over the # of instructions, we cache - // this property in should_make_rhs_col_major -- it maps a constant to true if - // all of the users of said constant are dot operations that can be sped up. - // This cache is populated lazily as we encounter dot operations traversing - // the instruction stream. - tensorflow::gtl::FlatMap - should_make_rhs_col_major_cache; - auto should_make_rhs_col_major = [&](const HloInstruction& instruction) { - if (ProfitableToImplementDotInUntiledLlvmIr(instruction) != - DotInLlvmIrProfitable::kWithColumnMajorRhs) { +// We want to change the layout of constant arrays to be column major when all +// of their users are dot operations that can be made faster with the flipped +// layout. To avoid going quadriatic over the # of instructions, we cache this +// property in should_make_rhs_col_major -- it maps a constant to true if all of +// the users of said constant are dot operations that can be sped up. This +// cache is populated lazily as we encounter dot operations traversing the +// instruction stream. + +namespace { +using ::tensorflow::gtl::nullopt; +using ::tensorflow::gtl::optional; + +using ShouldMakeOperandColMajorCache = + tensorflow::gtl::FlatMap; +} // namespace + +static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { + for (auto* user : instruction->users()) { + optional operand_idx = ProfitableToMakeDotOperandColumnMajor(*user); + if (!operand_idx || user->operand(*operand_idx) != instruction || + std::count(user->operands().begin(), user->operands().end(), + instruction) != 1) { return false; } + } + return true; +} - const auto* rhs = instruction.operand(1); - if (rhs->opcode() != HloOpcode::kConstant) { - return false; - } +static optional ShouldMakeOperandColumnMajor( + ShouldMakeOperandColMajorCache* cache, const HloInstruction& instruction) { + optional operand_idx = + ProfitableToMakeDotOperandColumnMajor(instruction); + if (!operand_idx) { + return nullopt; + } - auto it = should_make_rhs_col_major_cache.find(rhs); - if (it != should_make_rhs_col_major_cache.end()) { - return it->second; - } + const HloInstruction* operand = instruction.operand(*operand_idx); + if (operand->opcode() != HloOpcode::kConstant) { + return nullopt; + } - bool result = std::all_of( - rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) { - return ProfitableToImplementDotInUntiledLlvmIr(*user) == - DotInLlvmIrProfitable::kWithColumnMajorRhs && - user->operand(0) != rhs; - }); + auto it = cache->find(operand); + if (it == cache->end()) { + auto insert_result = + cache->insert({operand, ShouldMakeAllUsersColMajor(operand)}); + CHECK(insert_result.second); + it = insert_result.first; + } - InsertOrDie(&should_make_rhs_col_major_cache, rhs, result); - return result; - }; + return it->second ? operand_idx : nullopt; +} + +static Shape RowMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +static Shape ColMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.begin(), dimension_order.end(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +Status CpuLayoutAssignment::AddBackendConstraints( + LayoutConstraints* constraints) { + ShouldMakeOperandColMajorCache cache; const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { @@ -91,9 +110,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. - Shape output_shape(row_major_shape(convolution->shape())); - Shape input_shape(row_major_shape(lhs_instruction->shape())); - Shape filter_shape(row_major_shape(rhs_instruction->shape())); + Shape output_shape(RowMajorShape(convolution->shape())); + Shape input_shape(RowMajorShape(lhs_instruction->shape())); + Shape filter_shape(RowMajorShape(rhs_instruction->shape())); // Set layouts of the instructions' shapes. TF_RETURN_IF_ERROR( @@ -102,11 +121,11 @@ Status CpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(filter_shape, convolution, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, convolution)); - } else if (should_make_rhs_col_major(*instruction)) { - auto* dot = instruction; - const auto& rhs_shape = dot->operand(1)->shape(); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(col_major_shape(rhs_shape), dot, 1)); + } else if (optional op_idx = + ShouldMakeOperandColumnMajor(&cache, *instruction)) { + const HloInstruction* op = instruction->operand(*op_idx); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + ColMajorShape(op->shape()), instruction, *op_idx)); } else if (PotentiallyImplementedAsEigenDot(*instruction)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, @@ -114,17 +133,17 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. - Shape output_shape(row_major_shape(dot->shape())); + Shape output_shape(RowMajorShape(dot->shape())); const HloInstruction* lhs_instruction = dot->operand(0); - Shape lhs_shape(row_major_shape(lhs_instruction->shape())); + Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); // dot is a kDot or a kTransposeDot fusion node. In the latter case, if // it represents X @ X, it may have just one operand. if (dot->operand_count() > 1) { const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(row_major_shape(rhs_instruction->shape())); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); } @@ -141,8 +160,12 @@ Status CpuLayoutAssignment::AddBackendConstraints( if (constraints->OperandBufferForwarded(instruction, operand_no)) { continue; } + // Skip operands with non-array shapes. + if (!ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + continue; + } Shape operand_shape( - row_major_shape(instruction->operand(operand_no)->shape())); + RowMajorShape(instruction->operand(operand_no)->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( operand_shape, instruction, operand_no)); } diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h similarity index 86% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.h rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 4fd8d68dd6b4f2a8b16f6c048743a996ea76a560..c8edbb9e15a5b6f9c574f5fe9d130d149499ebd2 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -38,4 +38,4 @@ class CpuLayoutAssignment : public LayoutAssignment { } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc similarity index 54% rename from tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 1ea5e8c7fc4896512e62396d0a756cda44785f11..6ba030fff3bbc5f413bfb133114ceb5309b77672 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include #include @@ -40,6 +40,8 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -61,8 +63,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -98,10 +100,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { HloInstruction::CreateParameter(1, lhs_shape, "param1")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs)); - auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs)); + auto dot_a_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); + auto dot_b_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); @@ -142,10 +144,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { HloInstruction::CreateParameter(1, lhs_b_shape, "param1")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_a_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs)); - auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_b_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs)); + auto dot_a_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); + auto dot_b_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); @@ -180,8 +182,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape))); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateParameter(0, rhs_shape, "param0")); - auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -220,8 +222,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1)); - auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -241,5 +243,172 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { EXPECT_NE(instruction->opcode(), HloOpcode::kCopy); } } + +struct DotOutputFusionLayoutAssignmentResult { + bool layout_assignment_changed_something; + const HloInstruction* dot_lhs_fusion_param; + const HloInstruction* dot_rhs_fusion_param; + const HloInstruction* addend_fusion_param; +}; + +static StatusOr RunDotOutputFusion( + HloModule* module, const string& test_name, int m, int k, int n, + const int64 dot_operand_idx_in_add) { + DotOutputFusionLayoutAssignmentResult result; + + CHECK(dot_operand_idx_in_add == 0 || dot_operand_idx_in_add == 1); + + auto builder = HloComputation::Builder(test_name); + + Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + + HloInstruction* dot_lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); + HloInstruction* addend = builder.AddInstruction( + HloInstruction::CreateParameter(1, dot_shape, "param1")); + HloInstruction* dot_rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape))); + HloInstruction* dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + HloInstruction* add_result; + if (dot_operand_idx_in_add == 0) { + add_result = builder.AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, dot_result, addend)); + } else { + add_result = builder.AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, addend, dot_result)); + } + + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloInstruction* fusion_instruction = + module->entry_computation()->AddInstruction(HloInstruction::CreateFusion( + dot_shape, HloInstruction::FusionKind::kOutput, add_result)); + TF_RETURN_IF_ERROR( + computation->ReplaceInstruction(add_result, fusion_instruction)); + + HloInstruction* fused_add = + fusion_instruction->fused_instructions_computation()->root_instruction(); + HloInstruction* fused_dot = fusion_instruction->FuseInstruction(dot_result); + + TF_RETURN_IF_ERROR( + computation->RemoveInstructionAndUnusedOperands(dot_result)); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_lhs_shape)); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape)); + *computation_layout.mutable_result_layout() = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape)); + + result.dot_lhs_fusion_param = + fusion_instruction->operand(fused_dot->operand(0)->parameter_number()); + result.dot_rhs_fusion_param = + fusion_instruction->operand(fused_dot->operand(1)->parameter_number()); + result.addend_fusion_param = fusion_instruction->operand( + fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number()); + + cpu::CpuLayoutAssignment layout_assignment(&computation_layout); + TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, + layout_assignment.Run(module)); + + return result; +} + +static void AssertCorrectLayoutForDotOutputFusion( + const HloComputation* computation, + const DotOutputFusionLayoutAssignmentResult& layout_assignment_result, + bool expect_col_major_dot_rhs) { + Layout expected_dot_rhs_layout = expect_col_major_dot_rhs + ? LayoutUtil::MakeLayout({0, 1}) + : LayoutUtil::MakeLayout({1, 0}); + EXPECT_TRUE(LayoutUtil::Equal( + expected_dot_rhs_layout, + layout_assignment_result.dot_rhs_fusion_param->shape().layout())); + + EXPECT_TRUE(LayoutUtil::Equal( + LayoutUtil::MakeLayout({1, 0}), + layout_assignment_result.dot_lhs_fusion_param->shape().layout())); + + EXPECT_TRUE(LayoutUtil::Equal( + LayoutUtil::MakeLayout({1, 0}), + layout_assignment_result.addend_fusion_param->shape().layout())); + EXPECT_THAT(computation->instructions(), Each(Not(op::Copy()))); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/true); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/true); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 7908dc173d79a4a9dcb6127ac344267e27d2b5f2..1ef45dbec39a0880ebb123ba3fcd1fd6c89eb39a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -37,6 +37,7 @@ extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; +extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; extern const char* const kEigenSingleThreadedMatMulF64SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 2ade455b8a0a43dda8c93bbb79891439da2e4f75..3e1f08071119c938619d02777513e5b834077118 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -44,6 +44,7 @@ namespace runtime { extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; extern const char* const kEigenConvF32SymbolName; +extern const char* const kEigenFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; extern const char* const kEigenSingleThreadedConvF32SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc index 181deedde71bab3cb9ef1820a88de557131b9311..b1c1142e8d988be2ca00809b4be505466071c72f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc @@ -19,7 +19,7 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" -#ifdef __AVX__ +#ifdef TF_XLA_HAS_AVX xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX( xla::cpu::runtime::V8F32AVX x) { return Eigen::internal::pexp(x); @@ -29,7 +29,7 @@ xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX( xla::cpu::runtime::V8F32AVX x) { return Eigen::internal::plog(x); } -#endif // __AVX__ +#endif // TF_XLA_HAS_AVX namespace xla { namespace cpu { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h index acfada8540d89bb098bb0b04e109441e2123e678..e5c782f93f54dc9f8f76fce7e4735a60e8847583 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h @@ -24,6 +24,11 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" +#if defined(__AVX__) +#include +#define TF_XLA_HAS_AVX +#endif + namespace xla { namespace cpu { namespace runtime { @@ -31,21 +36,25 @@ namespace runtime { extern const char *const kExpV8F32AVXSymbolName; extern const char *const kLogV8F32AVXSymbolName; -typedef float V8F32AVX __attribute__((__vector_size__(32))); +#ifdef TF_XLA_HAS_AVX +typedef __m256 V8F32AVX; +#endif } // namespace runtime } // namespace cpu } // namespace xla extern "C" { +#ifdef TF_XLA_HAS_AVX // The following functions are vectorized versions of a selection of libm // library functions. // References to these functions are created by the LLVM vectorizer. xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX( - xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V8F32AVX x); xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX( - xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V8F32AVX x); +#endif } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc index abe792b2787ce8baf56ee62585a0ab886d922a23..8099b722f10ecb83f7cf6c58ba2abb783478b97f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc @@ -19,7 +19,7 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" -#ifdef __ARM_NEON__ +#ifdef TF_XLA_HAS_NEON xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON( xla::cpu::runtime::V4F32NEON x) { @@ -32,7 +32,7 @@ xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON( return Eigen::internal::plog(p); } -#endif // __ARM_NEON__ +#endif // TF_XLA_HAS_NEON namespace xla { namespace cpu { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h index 75cb16b273973d2bf665d378084343fd612a2941..2f5d1a872aaf3868d6d27f88a4f05c778d45660f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h @@ -27,6 +27,7 @@ limitations under the License. // __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM // NEON SIMD types is not portable, so the type has to come from #include +#define TF_XLA_HAS_NEON #endif // __ARM_NEON__ namespace xla { @@ -36,12 +37,9 @@ namespace runtime { extern const char *const kExpV4F32NEONSymbolName; extern const char *const kLogV4F32NEONSymbolName; -#ifdef __ARM_NEON__ +#ifdef TF_XLA_HAS_NEON typedef float32x4_t V4F32NEON; -#else -// On non-ARM platforms ensure the declaration is present -struct V4F32NEON; -#endif // __ARM_NEON__ +#endif // TF_XLA_HAS_NEON } // namespace runtime } // namespace cpu @@ -49,14 +47,16 @@ struct V4F32NEON; extern "C" { +#ifdef TF_XLA_HAS_NEON // The following functions are vectorized versions of a selection of libm // library functions. // References to these functions are created by the LLVM vectorizer. xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON( - xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V4F32NEON x); xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON( - xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V4F32NEON x); +#endif // TF_XLA_HAS_NEON } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc index a9a45db5a424d2faecbd437542c41fbd7fdf0bb8..d8ecf231cc8c859ac88e1ef1478f7107cd86a052 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc @@ -19,7 +19,7 @@ limitations under the License. #include "third_party/eigen3/Eigen/Core" -#ifdef __SSE4_1__ +#ifdef TF_XLA_HAS_SSE4_1 xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE( xla::cpu::runtime::V4F32SSE x) { @@ -33,7 +33,7 @@ xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( return Eigen::internal::plog(p); } -#endif // __SSE4_1__ +#endif // TF_XLA_HAS_SSE4_1 namespace xla { namespace cpu { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h index 96587d10d2b86e14ff6a7400fdf14ca0d994ddc5..aeb1eda23f76a6b5cb520b6673e0a011fa1130c7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h @@ -24,6 +24,13 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" +// MSVC does not have __SSE4_1__ macro. Eigen enables EIGEN_VECTORIZE_SSE4_1 +// when __AVX__ is defined, we should do the same. +#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__)) +#include +#define TF_XLA_HAS_SSE4_1 +#endif + namespace xla { namespace cpu { namespace runtime { @@ -31,7 +38,9 @@ namespace runtime { extern const char *const kExpV4F32SSESymbolName; extern const char *const kLogV4F32SSESymbolName; -typedef float V4F32SSE __attribute__((__vector_size__(16))); +#ifdef TF_XLA_HAS_SSE4_1 +typedef __m128 V4F32SSE; +#endif } // namespace runtime } // namespace cpu @@ -39,14 +48,16 @@ typedef float V4F32SSE __attribute__((__vector_size__(16))); extern "C" { +#ifdef TF_XLA_HAS_SSE4_1 // The following functions are vectorized versions of a selection of libm // library functions. // References to these functions are created by the LLVM vectorizer. xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE( - xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V4F32SSE x); xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( - xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK; + xla::cpu::runtime::V4F32SSE x); +#endif } #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index b53719fcc260d706eab3d7460c42af4a1b5e775f..f5e61aef534da57ce13d3ee9bbeaeaec31f53d2e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -98,7 +98,7 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, if (!ShapeUtil::IsTuple(shape)) { int64 size = GetByteSizeRequirement(shape); - return TransferBufferToInfeed(executor, size, literal.InternalData()); + return TransferBufferToInfeed(executor, size, literal.untyped_data()); } if (ShapeUtil::IsNestedTuple(shape)) { @@ -111,20 +111,20 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, // enqueue the resulting destination device addresses with the // infeed manager. std::vector buffers; - buffers.reserve(literal.tuple_literals_size()); + buffers.reserve(ShapeUtil::TupleElementCount(shape)); auto cleanup = tensorflow::gtl::MakeCleanup([&buffers]() { for (cpu::runtime::XfeedBuffer* b : buffers) { b->Done(Cancelled("Failed to infeed buffer to device.")); } }); - for (const auto& tuple_element : literal.tuple_literals()) { - const Shape& tuple_element_shape = tuple_element.shape(); + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& tuple_element_shape = ShapeUtil::GetSubshape(shape, {i}); int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); TF_ASSIGN_OR_RETURN( cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, tuple_element_size, - tuple_element.InternalData())); + literal.untyped_data({i}))); buffers.push_back(buffer); } @@ -187,14 +187,14 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( literal_shape.element_type(), dimensions)); TF_ASSIGN_OR_RETURN(Shape received_shape, TransferArrayBufferFromOutfeed( - executor, literal->MutableInternalData(), size)); + executor, literal->untyped_data(), size)); TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) << "Shape received from outfeed " << ShapeUtil::HumanString(received_shape) << " did not match the shape that was requested for outfeed: " << ShapeUtil::HumanString(literal_shape); TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); - *literal->mutable_shape() = received_shape; + *literal->mutable_shape_do_not_use() = received_shape; return Status::OK(); } @@ -217,7 +217,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( auto empty = Literal::CreateFromDimensions( tuple_element_shape.element_type(), dimensions); int64 size = GetByteSizeRequirement(tuple_element_shape); - buffer_data.push_back({empty->MutableInternalData(), size}); + buffer_data.push_back({empty->untyped_data(), size}); elements.push_back(std::move(empty)); } @@ -233,7 +233,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( GetByteSizeRequirement(received_shape)); for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { - *elements[i]->mutable_shape() = received_shape.tuple_shapes(i); + *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i); } *literal = std::move(*Literal::MakeTupleOwned(std::move(elements))); TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.h b/tensorflow/compiler/xla/service/cpu/disassembler.h index b6feaa7e45cee26eb7f850081bd1fad2cb63b15c..5e302f88990ee4a3c37758881ecec4d6f71dd8e6 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.h +++ b/tensorflow/compiler/xla/service/cpu/disassembler.h @@ -37,7 +37,7 @@ struct DisassemblerResult { DisassemblerResult(const string& text, size_t code_size_bytes) : text(text), code_size_bytes(code_size_bytes) {} - // The dissassembled text sections of the object file. + // The disassembled text sections of the object file. string text; // The total number of bytes of executable code in the object file. uint64_t code_size_bytes; @@ -53,7 +53,7 @@ class Disassembler { // Returns a DisassemblerResult for the given object file, containing the // disassembled code. // - // If we couldnt' retrieve a disassembler for this platform, an error status + // If we couldn't retrieve a disassembler for this platform, an error status // is returned. StatusOr DisassembleObjectFile( const llvm::object::ObjectFile& object_file) const; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 4c40dae5122b0853a72d6428fc120220e3a69237..c9fc586b9a4c06eb9e1f111d8f9bd2f717990aab 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,10 +23,11 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -143,7 +144,8 @@ class ColumnMajorMatrixVectorProductEmitter { ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* ir_builder) : scalar_type_(scalar_type), tile_rows_(tile_rows), @@ -152,6 +154,7 @@ class ColumnMajorMatrixVectorProductEmitter { k_(k), lhs_(lhs), rhs_(rhs), + addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), @@ -173,7 +176,7 @@ class ColumnMajorMatrixVectorProductEmitter { } // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous - // sequnce of `count` values, each one broadcasted to the vector width. + // sequence of `count` values, each one broadcasted to the vector width. std::vector LoadRhsTile(llvm::Value* offset, int64 count) { llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); std::vector result; @@ -198,6 +201,7 @@ class ColumnMajorMatrixVectorProductEmitter { int64 k_; llvm::Value* lhs_; llvm::Value* rhs_; + llvm::Value* addend_; llvm::Value* result_; llvm::IRBuilder<>* ir_builder_; KernelSupportLibrary ksl_; @@ -242,9 +246,10 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( /*step=*/tile_rows_, [&](llvm::Value* row) { std::vector lhs_tile = lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = is_first_column - ? vsl_.GetZeroVector() - : vsl_.LoadVector(result_, row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); for (int i = 0; i < columns; i++) { accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); } @@ -288,7 +293,18 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( ir_builder_->getInt1(is_first_tiled_column)); ksl_.If( setting_result_first_time, - [&]() { vsl_.StoreScalar(product, result_, scalar_row); }, + /*true_block_generator=*/ + [&]() { + if (addend_) { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), + product), + result_, scalar_row); + } else { + vsl_.StoreScalar(product, result_, scalar_row); + } + }, + /*false_block_generator=*/ [&]() { vsl_.StoreScalar( vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), @@ -353,7 +369,7 @@ class RowMajorMatrixVectorProductEmitter { RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, llvm::Value* rhs, - llvm::Value* result, + llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* ir_builder) : scalar_type_(scalar_type), tile_rows_(tile_rows), @@ -362,6 +378,7 @@ class RowMajorMatrixVectorProductEmitter { k_(k), lhs_(lhs), rhs_(rhs), + addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), @@ -394,6 +411,7 @@ class RowMajorMatrixVectorProductEmitter { int64 k_; llvm::Value* lhs_; llvm::Value* rhs_; + llvm::Value* addend_; llvm::Value* result_; llvm::IRBuilder<>* ir_builder_; KernelSupportLibrary ksl_; @@ -415,11 +433,32 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, &scalar_accumulators); + std::vector accumulator_values; + std::transform( + vector_accumulators.begin(), vector_accumulators.end(), + std::back_inserter(accumulator_values), + [](const VectorVariable& vector_var) { return vector_var.Get(); }); + + std::vector horizontal_sums; + if (row_count == vsl_.vector_size()) { + if (addend_) { + horizontal_sums = vsl_.ComputeHorizontalSums( + std::move(accumulator_values), vsl_.LoadVector(addend_, row)); + } else { + horizontal_sums = + vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + } else { + horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + for (int i = 0; i < row_count; i++) { llvm::Value* result_value = - vsl_.Add(vsl_.AddReduce(vector_accumulators[i].Get()), - scalar_accumulators[i].Get()); + vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row); + if (addend_ && row_count != vsl_.vector_size()) { + result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); + } vsl_.StoreScalar(result_value, result_, offset); } } @@ -483,49 +522,52 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } // namespace -DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, - bool transpose_rhs, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config) +DotOpEmitter::DotOpEmitter( + const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, + const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_(dot), transpose_lhs_(transpose_lhs), transpose_rhs_(transpose_rhs), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), + addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), ir_builder_(ir_builder), - hlo_module_config_(hlo_module_config) {} + hlo_module_config_(hlo_module_config), + target_machine_features_(target_machine_features) {} /* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config) { + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, - lhs_array, rhs_array, executable_run_options_value, - ir_builder, hlo_module_config); + lhs_array, rhs_array, addend_array, + executable_run_options_value, ir_builder, + hlo_module_config, target_machine_features); return dot_emitter.Emit(); } bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { - if (dot_.shape().dimensions_size() != 2 || - ProfitableToImplementDotInUntiledLlvmIr(dot_) == - DotInLlvmIrProfitable::kYes) { + if (dot_.shape().dimensions_size() != 2) { return false; } - if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) && - !primitive_util::IsIntegralType(dot_.shape().element_type())) { + PrimitiveType primitive_type = dot_.shape().element_type(); + + if (!primitive_util::IsFloatingPointType(primitive_type) && + !primitive_util::IsIntegralType(primitive_type)) { return false; } @@ -575,30 +617,76 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { int64 tiling_factor = GetGemvTilingFactor(); CHECK_GT(tiling_factor, 0); + llvm::Value* result_op = target_array_.GetBasePointer(); + llvm::Value* lhs_op = + swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer(); + llvm::Value* rhs_op = + swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + const int target_vector_register_element_size = + target_machine_features_.vector_register_num_elements( + *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + + // We may not always know the vector register size for the target we're + // compiling against, in which case target_vector_register_element_size is 0. + // In these cases we choose a default LLVM IR register size. + const int kUnknownTargetVectorRegisterSize = 4; + const int vector_register_element_size = + target_vector_register_element_size == 0 + ? kUnknownTargetVectorRegisterSize + : target_vector_register_element_size; + if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - ColumnMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/8, - /*tile_cols=*/tiling_factor, m, k, - swap_operands ? rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); + int64 tile_rows = vector_register_element_size; + int64 tile_cols = tiling_factor; + + string kernel_name = tensorflow::strings::StrCat( + "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, + addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, + llvm::Value* result_op) { + ColumnMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + addend_op, result_op, ir_builder_); + emitter.Emit(); + }); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - RowMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/tiling_factor, - /*tile_cols=*/8, m, k, - swap_operands ? rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); + int64 tile_rows = tiling_factor; + int64 tile_cols = vector_register_element_size; + + string kernel_name = tensorflow::strings::StrCat( + "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, + addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, + llvm::Value* result_op) { + RowMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + addend_op, result_op, ir_builder_); + emitter.Emit(); + }); } return true; @@ -641,6 +729,8 @@ tensorflow::Status DotOpEmitter::Emit() { return Status::OK(); } + CHECK_EQ(addend_array_, nullptr); + if (PotentiallyImplementedAsEigenDot(dot_)) { return EmitCallToRuntime(); } @@ -915,8 +1005,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), - lhs_shape.layout().minor_to_major(0) == 0, - rhs_shape.layout().minor_to_major(0) == 0}; + LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, + LayoutUtil::Minor(rhs_shape.layout(), 0) == 0}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -927,8 +1017,8 @@ llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( // reduction dimension. std::vector dimensions; const Shape& shape = operand_array.GetShape(); - for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape.layout().minor_to_major(i); + for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) { + int64 dimension = LayoutUtil::Minor(shape.layout(), i); if (dimension != reduction_dimension) { dimensions.push_back(dimension); } @@ -977,9 +1067,7 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { return false; } - if (ProfitableToImplementDotInUntiledLlvmIr(hlo) == - DotInLlvmIrProfitable::kYes || - ProfitableToImplementDotInTiledLlvmIr(hlo)) { + if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { return false; } @@ -1010,46 +1098,42 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { return false; } -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot) { - if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { - const Shape& result_shape = dot.shape(); - // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 - // cache line size, so that we can have the reduction dimension of both the - // LHS and RHS matrices and still have some space "left over". This needs - // to be tuned further. - const int64 kReductionDimensionThresholdBytes = 8 * 1024; - const bool single_threaded_eigen = - !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); - - // This is the point at which it is better to call into Eigen and shard the - // dot across multiple worker threads. This is a rough estimate by running - // a matmult benchmark on my local machine, and it can be tuned further. - const int64 kMaxSingleThreadedFlops = 16 * 1024; - - const int64 M = result_shape.dimensions(0); - const int64 N = result_shape.dimensions(1); - const int64 K = dot.operand(1)->shape().dimensions(0); - const int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); - if (M == 1 && - K * primitive_type_size <= kReductionDimensionThresholdBytes && - (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { - // Heuristics: - // - // - Look for a configuration where we will likely be able to keep LHS in - // L1 and do a cache-optimal traversal of RHS. - // - // - Bail out on matrices that are large enough that Eigen can profitably - // shard the computation across multiple cores. This only applies when - // multi-threading is enabled. - return LayoutUtil::IsMonotonicWithDim0Major( - dot.operand(1)->shape().layout()) - ? DotInLlvmIrProfitable::kWithColumnMajorRhs - : DotInLlvmIrProfitable::kYes; +// For vector-matrix dot products, it is always profitable to make the Rhs +// column major. +tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( + const HloInstruction& hlo) { + if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 && + hlo.shape().dimensions(0) == 1) { + if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) { + return 1; + } + return {}; + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) { + auto* fusion_root = + hlo.fused_instructions_computation()->root_instruction(); + if (fusion_root->opcode() != HloOpcode::kAdd) { + return {}; + } + + for (auto* fusion_root_op : fusion_root->operands()) { + if (fusion_root_op->opcode() != HloOpcode::kDot) { + continue; + } + if (auto operand_num = + ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) { + auto* operand = fusion_root_op->operand(*operand_num); + if (operand->opcode() == HloOpcode::kParameter && + operand->user_count() == 1) { + return operand->parameter_number(); + } + } } } - return DotInLlvmIrProfitable::kNo; + + return {}; } bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index c9168ccc0f6629c2a2bfbc7d4dc9c7ebab0a5708..9d748eb81f7850f3ccdb10f076eecfdc8326c05f 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -32,19 +33,11 @@ namespace cpu { bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); -enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; - -// Returns a value to indicate if (and under what conditions) will lowering -// |dot| as a untiled LLVM IR dot operation be profitable over calling into -// Eigen or emitting a tiled LLVM IR implementation. Possible return values -// are: -// -// * DotInLlvmIrProfitable::kYes - always profitable. -// * DotInLlvmIrProfitable::kNo - never profitable. -// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make -// the Rhs layout column major. -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot); +// Returns the index for an operand to `hlo` that should ideally be column +// major. Returns nullopt if there is no such operand or if `hlo` is not a dot +// or a fusion containing a dot. +tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( + const HloInstruction& hlo); // Returns true to indicate that we can generate a tiled LLVM IR implementation // for |dot|. @@ -57,21 +50,29 @@ class DotOpEmitter { // place the result in target_array. IR is emitted at current insert point of // the builder. Upon completion of the method, the insert point is set to the // end of all instructions emitted for this operation. + // + // If `addend_array` is not nullptr then it must be an array of the same + // dimensions as the result, and the result is computed as `addend_array` + + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported + // for Matrix-vector products. static tensorflow::Status EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config); + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); private: DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config); + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); // Emits the IR to perform the dot operation. tensorflow::Status Emit(); @@ -140,9 +141,11 @@ class DotOpEmitter { const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; const llvm_ir::IrArray& rhs_array_; + const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* ir_builder_; const HloModuleConfig& hlo_module_config_; + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index ba693ec89ab7c4090f8c9d1e4d65f17a80d0ac55..ebd96c4c42759b71b79408c73814605301af03c1 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -44,15 +44,11 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( default: return Unimplemented("tanh"); } - // Create function type for the function. - llvm::FunctionType* function_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - /*isVarArg=*/false); // Create function declaration for 'tanhf'. llvm::Function* function = llvm::cast(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), function_type)); + llvm_ir::AsStringRef(function_name), operand_value->getType(), + operand_value->getType())); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); @@ -64,6 +60,31 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr CpuElementalIrEmitter::EmitAtan2( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { + string function_name; + switch (prim_type) { + case F32: + function_name = "atan2f"; + break; + case F64: + function_name = "atan2"; + break; + default: + return Unimplemented("atan2"); + } + // Create function declaration for 'atan2'. + llvm::Function* function = + llvm::cast(module_->getOrInsertFunction( + llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(), + rhs->getType())); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create instruction to call 'atan2'. + return ir_builder_->CreateCall(function, {lhs, rhs}); +} + llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 7e9f27befb456c17581f556868712f92fd8fd083..4446dfd2821fb4b6e75f33694367392ecbcdd8bf 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -41,6 +41,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { protected: StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; + StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc index c9f8e5584965d0c73771750e26bd63c401d5b0c0..7dcc4ca7fa08b478f24065275ffa69725dc51682 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -33,15 +33,12 @@ void ExternalConstantPool::Insert(string name, const Literal& literal, CHECK(entries_.find(name) == entries_.end()); int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); - void* raw_pointer; - CHECK_EQ( - posix_memalign(&raw_pointer, std::max(alignment, sizeof(void*)), - literal_size), - 0) - << "failed to allocate " << literal_size << " bytes with alignment of " - << alignment; - - std::memcpy(raw_pointer, literal.InternalData(), literal_size); + void* raw_pointer = tensorflow::port::AlignedMalloc( + literal_size, std::max(alignment, sizeof(void*))); + CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size + << " bytes with alignment of " << alignment; + + std::memcpy(raw_pointer, literal.untyped_data(), literal_size); entries_.emplace(std::move(name), static_cast(raw_pointer)); } diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h index ade28cbcbcfda05a9ad0adab1139bf316720e11f..9c00d476b1fca6c3174af4ebb62dbbde324fd0ea 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/mem.h" namespace xla { namespace cpu { @@ -49,10 +50,10 @@ class ExternalConstantPool { const uint8* Find(const string& name); private: - // We need to `free()` pointers allocated into `entries_` since we allocate - // them with `posix_memalign`. + // We need to `AlignedFree` pointers allocated into `entries_` since we + // allocate them with `AlignedMalloc`. struct FreeDeleter { - void operator()(void* ptr) { free(ptr); } + void operator()(void* ptr) { tensorflow::port::AlignedFree(ptr); } }; tensorflow::gtl::FlatMap> diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 3993779da636e519f8d8fded468c3271d27ee093..788217aab6172b4e548452b3f6ffd4197c163ce4 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -44,6 +44,9 @@ bool PotentiallyImplementedAsEigenConvolution( ShapeUtil::ElementIsComplex(kernel_shape)) { return false; } + if (window_util::HasWindowReversal(convolution.window())) { + return false; + } const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index ac361ddfb4c8d253ffb1c99200939f6324cad2bb..34b2003916933f5ec0a15d9e219063c0a912fa40 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -23,6 +24,19 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution); + +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// See IrFunction and ParallelLoopEmitter for details. +using DynamicLoopBounds = std::vector>; + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 502dd2e7387d701e69e1c7ecb67fbdac26c6b5de..cfdf9f4ebc5a5ae2b0188c86edcdc70e3a596971 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "llvm/CodeGen/TargetRegisterInfo.h" @@ -42,6 +43,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -76,16 +79,16 @@ namespace cpu { IrEmitter::IrEmitter( const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, - std::unordered_map hlo_to_profile_idx, - tensorflow::gtl::optional entry_computation_profile_idx, + std::unordered_map instruction_to_profile_idx, + std::unordered_map computation_to_profile_idx, llvm::TargetMachine* target_machine, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), ir_builder_(llvm_module->getContext()), - hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), - entry_computation_profile_idx_(std::move(entry_computation_profile_idx)), + instruction_to_profile_idx_(std::move(instruction_to_profile_idx)), + computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), parallel_cpu_backend_( @@ -117,138 +120,33 @@ StatusOr IrEmitter::EmitComputation( // readcyclecounter if it is unavailable. bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; - profiling_state_ = ProfilingState(is_top_level_computation_, use_rdtscp, - GetProfileCountersArgument()); + profiling_state_ = ProfilingState(use_rdtscp, GetProfileCountersArgument()); if (instruction_order == nullptr) { TF_RETURN_IF_ERROR(computation->Accept(this)); } else { TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); } - InsertOrDie(&emitted_functions_, computation, compute_function_); - - return compute_function_; -} - -static llvm::Argument* GetArg(llvm::Function* f, int idx) { - llvm::Function::arg_iterator arg_iter = f->arg_begin(); - std::advance(arg_iter, idx); - return &*arg_iter; + llvm::Function* ir_function = compute_function_->function(); + InsertOrDie(&emitted_functions_, computation, ir_function); + // Delete 'compute_function', finalizing 'ir_function' and restoring caller + // IR insert point. + compute_function_.reset(); + return ir_function; } void IrEmitter::InitializeIrFunction(const string& function_name) { - // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, - // i64* dynamic_loop_bounds, i64* prof_counters) - // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. - // - // Therefore, the generated function's signature (FunctionType) is statically - // determined - parameter unpacking is done in code generated into the - // function, rather than by a prologue dictated by the platform ABI. - // - // /--------------\ - // retval ----------> | return value | - // \--------------/ - // - // /-------------------------------\ - // run_options -----> | xla::ExecutableRunOptions | - // \-------------------------------/ - // - // /---------------------------------------------\ - // params --------> | param 0 | param 1 | ..... | param N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | param 0 | | param 1 | | param N-1 | - // \---------/ \---------/ \-----------/ - // - // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | temp 0 | | temp 1 | | temp N-1 | - // \---------/ \---------/ \-----------/ - // - // /--------------------------------------------\ - // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| - // (elided for aot) \--------------------------------------------/ - // - // /---------------------------------------------\ - // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | - // (elided for aot) \---------------------------------------------/ - - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. - llvm::FunctionType* compute_function_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/GetComputeFunctionParams(), - /*isVarArg=*/false); - // Functions with local linkage get an inlining bonus. Because we know // a-priori that embedded functions (non-entry functions) will not have its // name resolved, give it local linkage. llvm::Function::LinkageTypes linkage = is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; - compute_function_ = - llvm::Function::Create(/*Ty=*/compute_function_type, - /*Linkage=*/linkage, - /*Name=*/AsStringRef(function_name), - /*Module=*/module_); - compute_function_->setCallingConv(llvm::CallingConv::C); - - // Set meaningful names for the function's arguments: useful for debugging. - llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin(); - arg_iter->setName("retval"); - (++arg_iter)->setName("run_options"); - (++arg_iter)->setName("params"); - (++arg_iter)->setName("temps"); - if (num_dynamic_loop_bounds_ > 0) { - (++arg_iter)->setName("dynamic_loop_bounds"); - } - (++arg_iter)->setName("prof_counters"); - - // We know a-priori that the function arguments are guaranteed to point to - // disjoint objects. - llvm::Argument* retval = GetResultArgument(); - for (llvm::Argument& argument : compute_function_->args()) { - // However, the return buffer aliases the temporaries and thus cannot be - // marked noalias. - if (&argument == retval) { - continue; - } - compute_function_->addAttribute(argument.getArgNo() + 1, - llvm::Attribute::NoAlias); - } - - // Add the optize attribute to the function if optimizing for size. This - // controls internal behavior of some optimization passes (e.g. loop - // unrolling). - if (options::OptimizeForSizeRequested(hlo_module_config_)) { - compute_function_->addFnAttr(llvm::Attribute::OptimizeForSize); - } - - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - compute_function_->addFnAttr("unsafe-fp-math", "true"); - compute_function_->addFnAttr("no-infs-fp-math", "true"); - compute_function_->addFnAttr("no-nans-fp-math", "true"); - compute_function_->addFnAttr("no-signed-zeros-fp-math", "true"); - } - - ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( - /*Context=*/module_->getContext(), - /*Name=*/"entry", - /*Parent=*/compute_function_)); + // Create and initialize new IrFunction. + compute_function_.reset( + new IrFunction(function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_enable_fast_math(), + module_, &ir_builder_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -344,11 +242,12 @@ int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { // Calculate the alignment of a buffer allocated for a given primitive type. int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { - int64 buffer_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - DCHECK_GE(buffer_size, 0); - DCHECK_LE(buffer_size, SIZE_MAX); - - return MinimumAlignmentForBufferSize(buffer_size); + int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + DCHECK_GE(byte_size, 0); + // Largest scalar is a complex64 so we don't need to worry about the + // int64->int truncation here. + DCHECK_LE(byte_size, 8); + return byte_size; } int64 IrEmitter::ByteSizeOf(const Shape& shape) const { @@ -357,6 +256,10 @@ int64 IrEmitter::ByteSizeOf(const Shape& shape) const { // Calculate the alignment of a buffer allocated for a given shape. int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { + if (ShapeUtil::IsScalar(shape)) { + return MinimumAlignmentForPrimitiveType(shape.element_type()); + } + int64 buffer_size = ByteSizeOf(shape); DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); @@ -612,7 +515,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32})); + /*supported_types=*/{F32, BF16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { @@ -898,6 +801,24 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32, F64, C64})); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions_size() != 1) { + // This is disallowed by ShapeInference today. + return Unimplemented( + "Dot with multiple contracting dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions(0) != + std::min(lhs->shape().dimensions_size() - 1, 1) || + dnums.rhs_contracting_dimensions(0) != 0) { + return Unimplemented( + "Dot with non-standard contracting dimensions not implemented."); + } llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -916,8 +837,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_); + lhs_array, rhs_array, /*addend_array=*/nullptr, + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + target_machine_features_); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -1189,8 +1111,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); llvm_ir::IrArray::Index kernel_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - kernel_index[dnums.kernel_spatial_dimensions(i)] = kernel_spatial[i]; + kernel_index[dnums.kernel_spatial_dimensions(i)] = + window.dimensions(i).window_reversal() + ? ir_builder_.CreateNSWSub( + ir_builder_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) + : kernel_spatial[i]; } + kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; @@ -1207,10 +1135,67 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { }); } +Status IrEmitter::HandleFft(HloInstruction* fft) { + auto operand = fft->operand(0); + TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( + /*instruction=*/*fft, /*operands=*/{operand}, + /*supported_types=*/{F32, C64})); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); + VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape()); + VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape()); + + llvm::Value* operand_address = GetEmittedValueFor(operand); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft)); + + const std::vector& fft_length = fft->fft_length(); + int64 input_batch = 1; + for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) { + input_batch *= fft->shape().dimensions(i); + } + + // Args have been computed, make the call. + llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); + llvm::Type* int32_type = ir_builder_.getInt32Ty(); + llvm::Type* int64_type = ir_builder_.getInt64Ty(); + llvm::FunctionType* fft_type = llvm::FunctionType::get( + ir_builder_.getVoidTy(), + {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, + int64_type, int64_type, int64_type, int64_type}, + /*isVarArg=*/false); + const char* fn_name = runtime::kEigenFftSymbolName; + llvm::Function* fft_func = llvm::cast( + module_->getOrInsertFunction(fn_name, fft_type)); + fft_func->setCallingConv(llvm::CallingConv::C); + fft_func->setDoesNotThrow(); + fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); + const int fft_rank = fft_length.size(); + ir_builder_.CreateCall( + fft_func, + {GetExecutableRunOptionsArgument(), + ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), + ir_builder_.CreateBitCast(operand_address, int8_ptr_type), + ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank), + ir_builder_.getInt64(input_batch), + ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + + return Status::OK(); +} + Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { + if (hlo_module_config_.replica_count() == 1) { + // When there is a single replica, a cross replica sum is the identity + // function, and the buffer assignment expects a copy (we could eliminate + // these at the HLO level as an optimization). + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + return EmitMemcpy(*crs->operand(0), *crs); + } + // TODO(b/33011107): Support cross replica sum on CPU. return Unimplemented( - "Cross replica sum not implemented on CPU. See b/33011107."); + "Cross replica sum is not implemented on CPU. See b/33011107."); } // Fills up the free variables in 'index_with_free_var' with values from @@ -1452,15 +1437,20 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { // // Where Param is the actual element type of the underlying buffer (for // example, float for an XLA F32 element type). - llvm::Argument* params = GetArg(compute_function_, 2); + llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = ir_builder_.CreateLoad(param_address_offset); param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); - if (hlo_module_config_.debug_options() + if (is_top_level_computation_ && + hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // We never reassign parameters, so this load is invariant. + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. param_address_untyped->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); @@ -1587,13 +1577,9 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( PrimitiveType element_type, unsigned element_count) { - // Here we assume that the largest register is a vector register. - int max_vector_register_size_in_bytes = - target_machine_features_.largest_register_size_in_bytes( - compute_function_); - int vector_register_size_in_elements = - max_vector_register_size_in_bytes / + target_machine_features_.vector_register_byte_size( + *compute_function_->function()) / ShapeUtil::ByteSizeOfPrimitiveType(element_type); ShardedVectorType sharded_vector_type; @@ -1748,19 +1734,6 @@ void IrEmitter::EmitShardedVectorStore( } } -namespace { -// TODO(sanjoy): This is duplicated in tensorflow/core/lib/core/arena.cc. -// Extract out a common implementation to tensorflow/core/lib/math/math_util.h -uint32 GCD(uint32 x, uint32 y) { - while (y != 0) { - uint32 r = x % y; - x = y; - y = r; - } - return x; -} -} // namespace - StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function, @@ -1781,11 +1754,12 @@ StatusOr IrEmitter::EmitVectorizedReduce( bool is_reduction_over_minor_dimension = std::find(dimensions.begin(), dimensions.end(), - arg->shape().layout().minor_to_major(0)) != dimensions.end(); + LayoutUtil::Minor(arg->shape().layout(), 0)) != + dimensions.end(); - unsigned element_alignment = - GCD(ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), - MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); + unsigned element_alignment = tensorflow::MathUtil::GCD( + ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), + MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); if (is_reduction_over_minor_dimension) { // TODO(sanjoy): Implement vectorized reduction over the minor dimension. @@ -1818,8 +1792,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_); llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size()); - for (int i = reduce->shape().layout().minor_to_major_size() - 1; i > 0; --i) { - int64 dimension = reduce->shape().layout().minor_to_major(i); + for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; + --i) { + int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); int64 start_index = 0; int64 end_index = reduce->shape().dimensions(dimension); std::unique_ptr loop = @@ -1828,7 +1803,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( array_index[dimension] = loop->GetIndVarValue(); } - int64 innermost_dimension = reduce->shape().layout().minor_to_major(0); + int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0); int64 innermost_dimension_size = reduce->shape().dimensions(innermost_dimension); @@ -1864,10 +1839,10 @@ StatusOr IrEmitter::EmitVectorizedReduce( target_array); if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) { - CHECK_GT(reduce->shape().layout().minor_to_major_size(), 1); + CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); ir_builder_.SetInsertPoint(exit_terminator); } else { - CHECK_EQ(reduce->shape().layout().minor_to_major_size(), 1); + CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); ir_builder_.SetInsertPoint(loop->GetExitBasicBlock()); } } @@ -1995,7 +1970,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); // The code below emits a sequential loop nest. For the parallel backend, use - // EmitParallelTargetElementLoop() which respects dynamic loop bounds. + // ParallelLoopEmitter which respects dynamic loop bounds. if (ShouldEmitParallelLoopFor(*slice)) { return DefaultAction(slice); } @@ -2027,7 +2002,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // * Implement the memcpy within the innermost loop. tensorflow::gtl::FlatSet inner_dims; - for (int64 dim : layout.minor_to_major()) { + for (int64 dim : LayoutUtil::MinorToMajor(layout)) { if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { break; } @@ -2054,7 +2029,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // memcpy_dim is the innermost (in terms of layout) dimension for which the // slice does *not* just copy all the elements along the dimension. - const int64 memcpy_dim = layout.minor_to_major(inner_dims.size()); + const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size()); const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1; // The number of logical elements that can be copied in a single call @@ -2263,8 +2238,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *root, root->operand(0)->IsRank2Transpose(), root->operand(1)->IsRank2Transpose(), target_array, lhs_array, - rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_)); + rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), + &ir_builder_, hlo_module_config_, target_machine_features_)); return Status::OK(); } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { @@ -2285,6 +2260,35 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); + } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) { + VLOG(3) << "HandleFusion kOutput"; + int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 0 : 1; + const HloInstruction* dot = root->operand(dot_op_index); + CHECK_EQ(dot->opcode(), HloOpcode::kDot) + << dot->ToString() << " " + << fusion->fused_instructions_computation()->ToString(); + + int64 dot_lhs_param_number = dot->operand(0)->parameter_number(); + int64 dot_rhs_param_number = dot->operand(1)->parameter_number(); + int64 addend_param_number = + root->operand(1 - dot_op_index)->parameter_number(); + + Shape target_shape = fusion->shape(); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + llvm_ir::IrArray target_array = GetIrArrayFor(fusion); + + llvm_ir::IrArray lhs_array( + GetIrArrayFor(fusion->operand(dot_lhs_param_number))); + llvm_ir::IrArray rhs_array( + GetIrArrayFor(fusion->operand(dot_rhs_param_number))); + llvm_ir::IrArray addend_array( + GetIrArrayFor(fusion->operand(addend_param_number))); + + TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( + *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, + lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), + &ir_builder_, hlo_module_config_, target_machine_features_)); + return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); } @@ -2305,9 +2309,17 @@ Status IrEmitter::HandleCall(HloInstruction* call) { !parallel_cpu_backend_) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. - TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, - emitted_value_[call], computation, - call_ir_function)); + std::vector call_args = GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, computation->name(), + /*return_value_buffer=*/emitted_value_[call], + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument()); + + HloInstruction* root = computation->root_instruction(); + TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( + call_args, root->shape(), root->outer_dimension_partitions(), + &ir_builder_, call_ir_function, computation->name())); } else { EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, emitted_value_[call], computation->name()); @@ -2410,7 +2422,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Terminates the current block with a branch to a while header. llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), - compute_function_); + compute_function_->function()); ir_builder_.CreateBr(header_bb); ir_builder_.SetInsertPoint(header_bb); @@ -2427,7 +2439,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Branches to the body or to the while exit depending on the condition. llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "body")), - compute_function_); + compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); @@ -2442,7 +2454,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { ir_builder_.CreateBr(header_bb); // Adds the exit block to the function and sets the insert point there. - compute_function_->getBasicBlockList().push_back(exit_bb); + compute_function_->function()->getBasicBlockList().push_back(exit_bb); ir_builder_.SetInsertPoint(exit_bb); return Status::OK(); @@ -2478,14 +2490,13 @@ StatusOr IrEmitter::EmitFastConcatenate( int64 concat_dim = concatenate->dimensions(0); const Layout& output_layout = output_shape.layout(); + auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); auto concat_dim_layout_itr = - std::find(output_layout.minor_to_major().begin(), - output_layout.minor_to_major().end(), concat_dim); + std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim); - std::vector inner_dims(output_layout.minor_to_major().begin(), - concat_dim_layout_itr); + std::vector inner_dims(output_min2maj.begin(), concat_dim_layout_itr); std::vector outer_dims(std::next(concat_dim_layout_itr), - output_layout.minor_to_major().end()); + output_min2maj.end()); llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::Type* i8_type = ir_builder_.getInt8Ty(); @@ -2560,7 +2571,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, const llvm_ir::IrArray& source_array) { unsigned primitive_type_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - unsigned element_alignment = GCD( + unsigned element_alignment = tensorflow::MathUtil::GCD( primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); @@ -2607,6 +2618,65 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { return DefaultAction(concatenate); } +Status IrEmitter::HandleConditional(HloInstruction* conditional) { + auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && + pred->shape().element_type() == PRED) + << "Predicate on a Conditional must be bool; got: " + << ShapeUtil::HumanString(pred->shape()); + + HloComputation* true_computation = conditional->true_computation(); + HloComputation* false_computation = conditional->false_computation(); + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + true_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the true " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << ShapeUtil::HumanString(true_computation->root_instruction()->shape()); + + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + false_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the false " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); + + llvm::Function* true_function = + FindOrDie(emitted_functions_, true_computation); + llvm::Function* false_function = + FindOrDie(emitted_functions_, false_computation); + + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); + llvm::Value* conditional_result = GetEmittedValueFor(conditional); + + // Generating: + // if (pred) + // cond_result = true_computation(true_operand) + // else + // cond_result = false_computation(false_operand) + llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + pred_value, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + "boolean_predicate"); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_); + + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, + conditional_result, IrName(conditional, "_true")); + + SetToFirstInsertPoint(if_data.false_block, &ir_builder_); + EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, + conditional_result, IrName(conditional, "_false")); + + SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding @@ -2618,57 +2688,51 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { llvm::Value* root_value = GetEmittedValueFor(root); VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); - llvm::Value* prof_counter = [&]() { - // For the parallel cpu backend, we record the total for each embedded - // computation callee with its caller kCall HLO. - if (parallel_cpu_backend_ && is_top_level_computation_) { - auto* computation = root->parent(); - auto* entry_computation = computation->parent()->entry_computation(); - if (computation != entry_computation) { - for (HloInstruction* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCall && - instruction->to_apply()->root_instruction() == root) { - return GetProfileCounterFor(*instruction); - } + auto record_complete_computation = [&](llvm::Value* prof_counter) { + if (prof_counter) { + profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); + } + }; + + // For the parallel cpu backend, we record the total for each embedded + // computation callee with its caller kCall HLO. + if (parallel_cpu_backend_ && is_top_level_computation_) { + auto* computation = root->parent(); + auto* entry_computation = computation->parent()->entry_computation(); + if (computation != entry_computation) { + for (HloInstruction* instruction : entry_computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCall && + instruction->to_apply()->root_instruction() == root) { + record_complete_computation(GetProfileCounterFor(*instruction)); + return Status::OK(); } } } - - // Otherwise we record the total computation cycles in a dedicated slot for - // the entry computation. - return GetProfileCounterForEntryComputation(); - }(); - - if (prof_counter) { - profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); } - ir_builder_.CreateRetVoid(); + + // For the entry computation this increment is cumulative of embedded + // computations since it includes cycles spent in computations invoked by + // While, Call etc. + record_complete_computation(GetProfileCounterFor(*root->parent())); return Status::OK(); } -llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction& hlo) { - auto it = hlo_to_profile_idx_.find(&hlo); - if (it == hlo_to_profile_idx_.end()) { +template +llvm::Value* IrEmitter::GetProfileCounterCommon( + const T& hlo, + const std::unordered_map& profile_index_map) { + auto it = profile_index_map.find(&hlo); + if (it == profile_index_map.end()) { return nullptr; } - size_t prof_counter_idx = it->second; + int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); return ir_builder_.CreateGEP(GetProfileCountersArgument(), ir_builder_.getInt64(prof_counter_idx), AsStringRef(counter_name)); } -llvm::Value* IrEmitter::GetProfileCounterForEntryComputation() { - if (entry_computation_profile_idx_) { - return ir_builder_.CreateGEP( - GetProfileCountersArgument(), - ir_builder_.getInt64(*entry_computation_profile_idx_), - "prof_counter.computation"); - } - return nullptr; -} - void IrEmitter::ProfilingState::UpdateProfileCounter( llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter, llvm::Value* cycle_end, llvm::Value* cycle_start) { @@ -2731,8 +2795,7 @@ void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* ir_builder, void IrEmitter::ProfilingState::RecordCompleteComputation( llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter) { - if (is_top_level_computation_ && last_read_cycle_end_ && - first_read_cycle_start_) { + if (last_read_cycle_end_ && first_read_cycle_start_) { UpdateProfileCounter(ir_builder, prof_counter, last_read_cycle_end_, first_read_cycle_start_); } @@ -2740,7 +2803,7 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); - if (hlo_to_profile_idx_.count(hlo)) { + if (instruction_to_profile_idx_.count(hlo)) { profiling_state_.RecordCycleStart(&ir_builder_, hlo); } return Status::OK(); @@ -2783,43 +2846,16 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { return llvm_ir::ShapeToIrType(shape, module_); } -std::vector IrEmitter::GetComputeFunctionParams() { - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (num_dynamic_loop_bounds_ > 0) { - compute_function_params.push_back(i64_ptr_type); - } - compute_function_params.push_back(i64_ptr_type); - return compute_function_params; -} - -llvm::Argument* IrEmitter::GetResultArgument() { - return GetArg(compute_function_, 0); -} - -llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; - return GetArg(compute_function_, arg_index); +llvm::Value* IrEmitter::GetProfileCountersArgument() { + return compute_function_->profile_counters_arg(); } llvm::Value* IrEmitter::GetTempBuffersArgument() { - return GetArg(compute_function_, 3); -} - -llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) { - CHECK_GT(num_dynamic_loop_bounds_, 0); - CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); - return ir_builder_.CreateLoad(ir_builder_.CreateGEP( - loop_bounds_arg, ir_builder_.getInt64(offset), AsStringRef(name))); + return compute_function_->temp_buffers_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { - return GetArg(compute_function_, 1); + return compute_function_->exec_run_options_arg(); } llvm::Value* IrEmitter::EmitTempBufferPointer( @@ -2850,10 +2886,14 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( GetTempBuffersArgument(), slice.index(), &ir_builder_); llvm::LoadInst* tempbuf_address_base = ir_builder_.CreateLoad(tempbuf_address_ptr); - if (hlo_module_config_.debug_options() + if (is_top_level_computation_ && + hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // Loading the address of a buffer is invariant of the point at which the - // load is executed in the program because we never reassign buffers. + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2884,42 +2924,6 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits code to allocate an array of parameter address pointers, and store -// each address from 'parameter_addresses'. -// Returns an array of compute function call arguments (including parameter -// address buffer). -std::vector IrEmitter::GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt8PtrTy(), - ir_builder_.getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), - &ir_builder_); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( - parameter_addresses[i], ir_builder_.getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); - llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder_.getInt64(i)}); - ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses); - } - - const auto to_int8_ptr = [this](llvm::Value* ptr) { - return ir_builder_.CreatePointerCast(ptr, ir_builder_.getInt8PtrTy()); - }; - std::vector arguments{ - to_int8_ptr(return_value_buffer), - to_int8_ptr(GetExecutableRunOptionsArgument()), - parameter_addresses_buffer, GetTempBuffersArgument()}; - if (auto* profile_counters = GetProfileCountersArgument()) { - arguments.push_back(profile_counters); - } - return arguments; -} - // Emits a core function call based on the following pseudo-code. // // char** parameter_addresses_buffer = @@ -2935,8 +2939,12 @@ void IrEmitter::EmitArrayFunctionCallInto( tensorflow::gtl::ArraySlice parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { ir_builder_.CreateCall( - function, GetArrayFunctionCallArguments(parameter_addresses, - return_value_buffer, name)); + function, GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::EmitArrayFunctionCall( @@ -2956,117 +2964,13 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -// Emits a call to a runtime fork/join function which dispatches parallel -// calls to 'parallel_function' (and joins threads before returning). -Status IrEmitter::EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function) { - HloInstruction* root = computation->root_instruction(); - - // Build ParallelForkJoin function type. - std::vector compute_function_params = GetComputeFunctionParams(); - // Number of parallel compute functions. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Array of partitions. There is an array element for each - // partition x partition_dim x 2 (for dimension start and limit). - compute_function_params.push_back( - llvm::Type::getInt64PtrTy(module_->getContext())); - // Number of partitioned most-major dimensions in 'root.shape'. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Function pointer for compute function to be dispatched in parallel. - compute_function_params.push_back( - llvm::Type::getInt8PtrTy(module_->getContext())); - - llvm::FunctionType* fork_join_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, - /*isVarArg=*/false); - - llvm::Function* fork_join_func = - llvm::cast(module_->getOrInsertFunction( - runtime::kParallelForkJoinSymbolName, fork_join_type)); - fork_join_func->setCallingConv(llvm::CallingConv::C); - fork_join_func->setDoesNotThrow(); - - // Add common compute function arguments. - const string name = computation->name(); - std::vector arguments = - GetArrayFunctionCallArguments(parameter_addresses, output_address, name); - - // Create ShapePartitionIterator to generate all partitions of 'root.shape'. - ShapePartitionIterator partition_iterator(root->shape(), - root->outer_dimension_partitions()); - const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); - // Add argument specifying the number of parallel partitions. - arguments.push_back(ir_builder_.getInt32(num_partitions)); - - // The number of partitioned most-major dimensions in 'root.shape'. - const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); - // A dimension partition consists of two elements: [start_index, limit_index). - const int32 dim_partition_size = 2; - // Calculate array partition stride. - const int32 array_partition_stride = - num_partitioned_dims * dim_partition_size; - // Calculate the total number of elements in the partition array. - const int32 partition_array_size = - dim_partition_size * num_partitioned_dims * num_partitions; - - // Store dimension partition values as llvm constants in 'partitions'. - // See comments in runtime_fork_join.cc for array layout description. - std::vector partitions(partition_array_size); - for (int32 i = 0; i < num_partitions; ++i) { - std::vector> dim_partitions = - partition_iterator.GetPartition(i); - CHECK_EQ(num_partitioned_dims, dim_partitions.size()); - const int32 partition_index = i * array_partition_stride; - for (int32 j = 0; j < num_partitioned_dims; ++j) { - const std::pair& dim_partition = dim_partitions[j]; - const int32 index = partition_index + j * dim_partition_size; - // Store partition [dim_start, dim_limit) intervals for each dimension. - partitions[index] = ir_builder_.getInt64(dim_partition.first); - partitions[index + 1] = - ir_builder_.getInt64(dim_partition.first + dim_partition.second); - } - } - - // Create global variable out of dimension partitions in 'partitions'. - llvm::ArrayType* partitions_array_type = - llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); - llvm::Constant* partitions_array = - llvm::ConstantArray::get(partitions_array_type, partitions); - llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/partitions_array_type, - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/partitions_array, - /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); - - // Add argument specifying parallel dimension partitions. - arguments.push_back(ir_builder_.CreateBitCast( - global_partitions_array, - llvm::Type::getInt64PtrTy(module_->getContext()))); - // Add argument specifying the number of partitioned most-major dimensions. - arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); - // Add argument for parallel compute function pointer. - arguments.push_back( - ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); - // Emit call to parallel fork/join. - ir_builder_.CreateCall(fork_join_func, arguments); - - return Status::OK(); -} - Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { llvm::Value* addr; const Shape& target_shape = op->shape(); if (op == op->parent()->root_instruction()) { // For the root node, we write directly to the output buffer of the // function. - llvm::Argument* retval = GetResultArgument(); + llvm::Argument* retval = compute_function_->result_arg(); if (!ShapeUtil::IsNil(target_shape)) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); @@ -3127,8 +3031,13 @@ Status IrEmitter::EmitTargetElementLoop( } else { if (ShouldEmitParallelLoopFor(*target_op)) { - TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( - target_shape, element_generator, IrName(target_op), &target_array)); + // Emit code to read dynamic loop bounds from compute function argument. + std::vector> dynamic_loop_bounds = + compute_function_->GetDynamicLoopBounds(); + // Emit parallel loop with dynamic loop bounds for most-major dimensions. + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, + &dynamic_loop_bounds, &ir_builder_) + .EmitLoop(IrName(target_op))); } else { TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) @@ -3138,60 +3047,6 @@ Status IrEmitter::EmitTargetElementLoop( return Status::OK(); } -Status IrEmitter::EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array) { - CHECK(!ShapeUtil::IsTuple(target_shape)); - CHECK(!ShapeUtil::IsScalar(target_shape)); - - // Emit code to read dynamic loop bounds from function argument 4. - std::vector dynamic_loop_bounds(2 * num_dynamic_loop_bounds_); - for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) { - dynamic_loop_bounds[i] = GetDynamicLoopBound(i); - } - - llvm_ir::ForLoopNest loop_nest(loop_name, &ir_builder_); - const int64 num_dims = target_shape.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); - - // Add loops from outer-most to inner-most dimensions. - for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - const int64 dimension = target_shape.layout().minor_to_major(i); - const int bounds_index = num_dims - 1 - i; - if (bounds_index < num_dynamic_loop_bounds_) { - // Emit dynamic loop bounds for this dimension. Dynamic loop bounds - // are read from ir function dynamic loop bounds argument. - llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0]; - llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1]; - - std::unique_ptr loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); - array_index[dimension] = loop->GetIndVarValue(); - } else { - // Emit static loop bounds for this dimension. - std::unique_ptr loop = loop_nest.AddLoop( - /*start_index=*/0, - /*end_index=*/target_shape.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); - array_index[dimension] = loop->GetIndVarValue(); - } - } - // Point IR builder at inner loop BB. - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Emit loop body. - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - element_generator(array_index)); - target_array->EmitWriteArrayElement(array_index, target_element, - &ir_builder_); - // Point IR builder at outer loop exit BB. - SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_); - - return Status::OK(); -} - Status IrEmitter::EmitMemcpy(const HloInstruction& source, const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); @@ -3249,37 +3104,5 @@ StatusOr IrEmitter::EmitScalarCall( ShapeUtil::MakeShape(return_type, {}), argument_addrs, name); } - -unsigned TargetMachineFeatures::largest_register_size_in_bytes( - llvm::Function* function) { - auto itr = largest_register_size_in_bytes_.find(function); - if (itr != largest_register_size_in_bytes_.end()) { - return itr->second; - } - - int result = largest_register_size_in_bytes_impl(function); - - InsertOrDie(&largest_register_size_in_bytes_, function, result); - DCHECK_EQ(result, largest_register_size_in_bytes_.begin()->second); - return result; -} - -unsigned TargetMachineFeatures::largest_register_size_in_bytes_impl( - llvm::Function* function) const { - auto register_info = - target_machine_->getSubtargetImpl(*function)->getRegisterInfo(); - - unsigned largest_register_size = 0; - for (const llvm::TargetRegisterClass* register_class : - register_info->regclasses()) { - if (register_class->isAllocatable()) { - largest_register_size = - std::max(largest_register_size, - register_info->getRegSizeInBits(*register_class)); - } - } - - return largest_register_size / 8; -} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 351c95278c17f536e56d9f085b938a9baea9cde1..66f2aeeab33dbaa34297c8dc6a37c3ad481820d8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -30,6 +31,8 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -49,49 +52,6 @@ limitations under the License. namespace xla { namespace cpu { - -// Wraps an llvm::TargetMachine and parses out some information that feeds into -// code LLVM IR generation decisions. -// -// Ideally we'd be able to use llvm::TargetTransformInfo here (since its -// interface is pretty much a perfect fit for our use case), but obtaining an -// instance of llvm::TargetTransformInfo outside an LLVM pass pipeline without -// super-ugly hacks is difficult. -// -// TODO(b/66049221): See if the LLVM community will be receptive to exposing an -// API that lets us directly create and use llvm::TargetTransformInfo instances -// outside of a pass manager. -class TargetMachineFeatures { - public: - TargetMachineFeatures(llvm::TargetMachine* target_machine) - : target_machine_(target_machine) {} - - // Return the vectorization factor, which is the number of bytes of data - // explicitly vectorized routines will try to process at once. - int vectorization_factor_in_bytes() const { - // Ideally this should be a function of the cache line size (which we can - // get from llvm::TargetTransformInfo::getCacheLineSize) of the target - // machine. Guess a value of 128 bytes for now. - return 128; - } - - // Return the size of the largest register size in bytes. We need to pass in - // "function" since llvm functions can contain annotations for specializing - // them to specific micro-architectures (though currently XLA does not use - // this functionality). - // - // Ideally we should have been able to use - // llvm::TargetTransformInfo::getRegisterBitWidth(true) here. - unsigned largest_register_size_in_bytes(llvm::Function* function); - - private: - unsigned largest_register_size_in_bytes_impl(llvm::Function* function) const; - - tensorflow::gtl::FlatMap - largest_register_size_in_bytes_; - llvm::TargetMachine* target_machine_; -}; - // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. @@ -103,20 +63,21 @@ class IrEmitter : public DfsHloVisitorWithDefault { // assignment: a BufferAssignment from which we know which temporary buffers // are used by the HLO nodes. // llvm_module: the LLVM module to emit IR into. - // hlo_to_profile_idx: the mapping from HLO to its index in the profiling - // array. - // entry_computation_profile_idx: the index in the profiling array - // for the entry computation. + // instruction_to_profile_idx: the mapping from HLO instructions to their + // index in the profiling array. + // computation_to_profile_idx: the mapping from HLO computations to their + // index in the profiling array. // external_constant_pool: if non-null, points to an ExternalConstantPool // instance into which the Ir emitter can spill // constants. - IrEmitter( - const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, - std::unordered_map hlo_to_profile_idx, - tensorflow::gtl::optional entry_computation_profile_idx, - llvm::TargetMachine* target_machine, - ExternalConstantPool* external_constant_pool); + IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, + llvm::Module* llvm_module, + std::unordered_map + instruction_to_profile_idx, + std::unordered_map + computation_to_profile_idx, + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -163,6 +124,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; + Status HandleFft(HloInstruction* fft) override; Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; @@ -189,6 +151,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleConditional(HloInstruction* conditional) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -198,14 +161,23 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Private helper to initialize an IR function for the computation. void InitializeIrFunction(const string& function_name); - // Convenience function to generate a GEP into the profile counter parameter - // which would correspond to the index for a given HLO. - llvm::Value* GetProfileCounterFor(const HloInstruction& hlo); + template + llvm::Value* GetProfileCounterCommon( + const T& hlo, + const std::unordered_map& profile_index_map); + + // Convenience functions to generate a GEP into the profile counter parameter + // which would correspond to the index for a given HLO instruction or + // computation. + llvm::Value* GetProfileCounterFor(const HloInstruction& instruction) { + return GetProfileCounterCommon(instruction, + instruction_to_profile_idx_); + } - // Convenience function to generate a GEP into the profile counter parameter - // corresponding to the index for the entry computation. Returns nullptr if - // profiling the entry computation is disabled. - llvm::Value* GetProfileCounterForEntryComputation(); + llvm::Value* GetProfileCounterFor(const HloComputation& computation) { + return GetProfileCounterCommon(computation, + computation_to_profile_idx_); + } // Gets the IR Value emitted previously for the given hlo. // @@ -233,16 +205,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); - // Returns an array of compute function parameter types. - std::vector GetComputeFunctionParams(); - - // Get the llvm::Value* that represents the "retval" argument of the - // computation function being emitted by this emitter. - llvm::Argument* GetResultArgument(); - // Get the llvm::Value* that represents the "prof_counters" argument of the // computation function being emitted by this emitter. - llvm::Argument* GetProfileCountersArgument(); + llvm::Value* GetProfileCountersArgument(); // Get the xla::ExecutableRunOptions that represents the "run_options" // argument of the computation function being emitted by this emitter. @@ -252,11 +217,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emit ir to read and return the ir value for the dynamic loop bound at - // 'offset' from the "dynamic_loop_bounds" argument of the computation - // function being emitted by this emitter. - llvm::Value* GetDynamicLoopBound(const int64 offset); - // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. @@ -310,18 +270,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name); - // Returns an array of compute function call arguments. - std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Emits a call to a runtime fork/join function which dispatches parallel - // calls to 'parallel_function' (and joins threads before returning). - Status EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function); - // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. Status ElementTypesSameAndSupported( @@ -346,15 +294,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, tensorflow::StringPiece desc, const llvm_ir::ElementGenerator& element_generator); - // Emit IR to perform a computation for every element in a partition/slice of - // 'target_shape'. The loop bounds for the outer-dimension partitions are - // passed into the compute function as a runtime argument (accessible from - // GetDynamicLoopBound). - Status EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array); - // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. @@ -476,13 +415,19 @@ class IrEmitter : public DfsHloVisitorWithDefault { thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory - // management rules, their memory is owned by the module. - llvm::Function* compute_function_; + // management rules, their memory is owned by the module (Note that IrFunction + // creates the encapsulated llvm::Function s.t. it is added to the llvm + // module's function list). + std::unique_ptr compute_function_; llvm::IRBuilder<> ir_builder_; - // Maps HLOs to their index into the profile counter array. - std::unordered_map hlo_to_profile_idx_; - const tensorflow::gtl::optional entry_computation_profile_idx_; + // Maps HLO instructions to their index into the profile counter array. + const std::unordered_map + instruction_to_profile_idx_; + + // Maps HLO computations to their index into the profile counter array. + const std::unordered_map + computation_to_profile_idx_; // Maps HLOs to Values emitted for them. std::unordered_map emitted_value_; @@ -490,7 +435,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm_ir::AliasAnalysis alias_analysis_; // The number of root instruction outer dimensions used in parallel loop - // emission (EmitParallelTargetElementLoop). + // emission (ParallelLoopEmitter). int64 num_dynamic_loop_bounds_ = 0; // Returns whether the given instruction should be emitted as a parallel loop. @@ -505,15 +450,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // profiling a computation. class ProfilingState { public: - ProfilingState() - : is_top_level_computation_(false), - use_rdtscp_(false), - prof_counters_(nullptr) {} - ProfilingState(bool is_top_level_computation, bool use_rdtscp, - llvm::Argument* prof_counters) - : is_top_level_computation_(is_top_level_computation), - use_rdtscp_(use_rdtscp), - prof_counters_(prof_counters) {} + ProfilingState() : use_rdtscp_(false), prof_counters_(nullptr) {} + ProfilingState(bool use_rdtscp, llvm::Value* prof_counters) + : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} // Record the cycle counter before an HLO executes. void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo); @@ -535,15 +474,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* cycle_start); private: - // Is this IrEmitter for a top-level computation? - bool is_top_level_computation_; - // Should we use the x86-specific rdtscp or the generic readcyclecounter // intrinsic? bool use_rdtscp_; // The argument which corresponds to the profile counter buffer. - llvm::Argument* prof_counters_; + llvm::Value* prof_counters_; // The first read cycle counter in the program. llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca8c290dd1c4959e42026c3917d37f8fc95a1011 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -0,0 +1,333 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +namespace { +using llvm_ir::AsStringRef; +} // namespace + +namespace cpu { + +static std::vector GetComputeFunctionParams( + llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = + llvm::Type::getInt64PtrTy(llvm_module->getContext()); + std::vector compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds > 0) { + compute_function_params.push_back(i64_ptr_type); + } + compute_function_params.push_back(i64_ptr_type); + return compute_function_params; +} + +IrFunction::IrFunction(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, + int64 num_dynamic_loop_bounds) + : ir_builder_(ir_builder), + llvm_module_(llvm_module), + caller_insert_point_guard_(*ir_builder), + num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { + Initialize(function_name, linkage, optimize_for_size_requested, + enable_fast_math); +} + +IrFunction::~IrFunction() { + // Emit function return value. + ir_builder_->CreateRetVoid(); +} + +DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { + DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_); + for (int i = 0; i < num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0); + dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1); + } + return dynamic_loop_bounds; +} + +void IrFunction::Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math) { + // The function signature is: + // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // i64* dynamic_loop_bounds, i64* prof_counters) + // + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: address of an array with pointers to temporary buffers. + // + // Therefore, the generated function's signature (FunctionType) is statically + // determined - parameter unpacking is done in code generated into the + // function, rather than by a prologue dictated by the platform ABI. + // + // /--------------\ + // retval ----------> | return value | + // \--------------/ + // + // /-------------------------------\ + // run_options -----> | xla::ExecutableRunOptions | + // \-------------------------------/ + // + // /---------------------------------------------\ + // params --------> | param 0 | param 1 | ..... | param N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | param 0 | | param 1 | | param N-1 | + // \---------/ \---------/ \-----------/ + // + // /---------------------------------------------\ + // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | temp 0 | | temp 1 | | temp N-1 | + // \---------/ \---------/ \-----------/ + // + // /--------------------------------------------\ + // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| + // (elided for aot) \--------------------------------------------/ + // + // /---------------------------------------------\ + // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | + // \---------------------------------------------/ + + // Even though the type of params and temps is void** in the host's view, in + // LLVM IR this is represented by i8*, similarly to void*. It's up to the code + // to use GEPs to unravel the indirection layers. + llvm::FunctionType* function_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), + /*Params=*/ + GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_), + /*isVarArg=*/false); + + // Functions with local linkage get an inlining bonus. Because we know + // a-priori that embedded functions (non-entry functions) will not have its + // name resolved, give it local linkage. + function_ = + llvm_ir::CreateFunction(function_type, linkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size_requested, + function_name, llvm_module_); + + // Set meaningful names for the function's arguments: useful for debugging. + llvm::Function::arg_iterator arg_iter = function_->arg_begin(); + arg_iter->setName("retval"); + result_arg_ = &*arg_iter; + (++arg_iter)->setName("run_options"); + exec_run_options_arg_ = &*arg_iter; + (++arg_iter)->setName("params"); + parameters_arg_ = &*arg_iter; + (++arg_iter)->setName("temps"); + temp_buffers_arg_ = &*arg_iter; + if (num_dynamic_loop_bounds_ > 0) { + (++arg_iter)->setName("dynamic_loop_bounds"); + dynamic_loop_bounds_arg_ = &*arg_iter; + } + (++arg_iter)->setName("prof_counters"); + profile_counters_arg_ = &*arg_iter; + + // We know a-priori that the function arguments are guaranteed to point to + // disjoint objects. + llvm::Argument* retval = result_arg(); + for (llvm::Argument& argument : function_->args()) { + // However, the return buffer aliases the temporaries and thus cannot be + // marked noalias. + if (&argument == retval) { + continue; + } + function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); + } + + ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( + /*Context=*/llvm_module_->getContext(), + /*Name=*/"entry", + /*Parent=*/function_)); +} + +llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { + CHECK_GT(num_dynamic_loop_bounds_, 0); + CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); + string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + return ir_builder_->CreateLoad( + ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), + ir_builder_->getInt64(offset), AsStringRef(name))); +} + +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. +// Returns an array of compute function call arguments (including parameter +// address buffer). +std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + llvm::Value* parameter_addresses_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir_builder->getInt8PtrTy(), + ir_builder->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), + ir_builder); + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( + parameter_addresses[i], ir_builder->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); + llvm::Value* slot_in_param_adresses = ir_builder->CreateInBoundsGEP( + parameter_addresses_buffer, {ir_builder->getInt64(i)}); + ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_adresses); + } + + const auto to_int8_ptr = [=](llvm::Value* ptr) { + return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); + }; + std::vector arguments{ + to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), + parameter_addresses_buffer, temp_buffers_arg}; + if (profile_counters_arg != nullptr) { + arguments.push_back(profile_counters_arg); + } + return arguments; +} + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + + // Build ParallelForkJoin function type. + std::vector compute_function_params = + GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module->getContext())); + // Number of partitioned most-major dimensions in 'shape'. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast(module->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + std::vector fork_join_arguments(arguments); + + // Create ShapePartitionIterator to generate all partitions of 'shape'. + ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'shape'. + const int32 num_partitioned_dims = dimension_partition_counts.size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. + const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder->getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder->getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + fork_join_arguments.push_back(ir_builder->CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + fork_join_arguments.push_back( + ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder->CreateCall(fork_join_func, fork_join_arguments); + + return Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h new file mode 100644 index 0000000000000000000000000000000000000000..1fd2da4dce23982ed030f3aa8ec604182d0ebab8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ + +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace cpu { + +// IrFunction creates and encapsulates an llvm::Function, exposing methods to +// emitters for function and function argument access. +// The llvm::Function is created with the standard function signature +// used in the XLA CPU backend (see ir_function.cc for argument details). +// In addtion IrFunction saves the callers IR insert point during contruction, +// and restores it after desctruction. +// +// Example usage: +// +// // Create and initialize new IrFunction. +// std::unique_ptr compute_function(new IrFunction(...)); +// // Emit IR for function body using IrFunction helper methods. +// ... +// // Store reference to llvm::Function for future invocation. +// ir_functions.push_back(compute_function.function()); +// // Delete IrFunction (finalizes IR function and restores caller insertion +// // point). +// compute_function.reset(); +// + +class IrFunction { + public: + IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds); + ~IrFunction(); + + // Emit ir to read and return the set of ir values representing the dynamic + // loop bounds argument of this function. + // Each element in returned vector is a pair of ir values representing + // the loop bounds for a specific dimension, where the first element of the + // pair is the dimension start index, and the second element of the pair + // is the dimension limit. + // EX: [dimension_i_index_start_ir_value, dimension_i_index_limit_ir_value] + // + DynamicLoopBounds GetDynamicLoopBounds(); + + // Returns the encapculated llvm::Function. + llvm::Function* function() { return function_; } + + // Get the llvm::Value* that represents this functions "retval" argument. + llvm::Argument* result_arg() { return result_arg_; } + + // Get the xla::ExecutableRunOptions that represents this functions + // "run_options" argument. + llvm::Value* exec_run_options_arg() { return exec_run_options_arg_; } + + // Get the llvm::Value* that represents this functions parameters argument. + llvm::Value* parameters_arg() { return parameters_arg_; } + + // Get the llvm::Value* that represents this functions "temps" argument. + llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + + // Get the llvm::Value* that represents this functions "prof_counters" + // argument. + llvm::Value* profile_counters_arg() { return profile_counters_arg_; } + + private: + // Initialize an llvm::Function with standard signature based on arguments. + void Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + bool optimize_for_size_requested, bool enable_fast_math); + + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of this function. + llvm::Value* GetDynamicLoopBound(int64 offset); + + llvm::IRBuilder<>* ir_builder_; + llvm::Module* llvm_module_; + llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_; + + int64 num_dynamic_loop_bounds_ = 0; + // Encapsulated llvm::Function. + llvm::Function* function_; + // Function argument IR values. + llvm::Argument* result_arg_; + llvm::Value* exec_run_options_arg_; + llvm::Value* parameters_arg_; + llvm::Value* temp_buffers_arg_; + llvm::Value* dynamic_loop_bounds_arg_ = nullptr; + llvm::Value* profile_counters_arg_; +}; + +// Returns an array of compute function call argument ir values. +std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name); + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 81c29e4726c7be53b433be896f558f502e43c885..0336fa61312e5cd626ae38ddd29875bff256212a 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -64,14 +64,14 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, &ir_builder), llvm::ConstantFP::get(vector_type, 9.0), &ir_builder); - std::array numerator_coeffs( - {{-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, - 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, - 4.89352455891786e-03f}}); - - std::array denominator_coeffs( - {{1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, - 4.89352518554385e-03f}}); + std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; llvm::Value* input_squared = ir_builder.CreateFMul(input_clamped, input_clamped); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index 0077e344e2bd34aa598ee076220fee678f31b4ad..d1b88b27f068962fb86477fcad3e4390b1636c2b 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -376,19 +376,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { - std::vector argument_buffers(arguments.size()); - for (int i = 0; i < arguments.size(); ++i) { - argument_buffers[i] = arguments[i]->buffer(/*index=*/{}); - } - return ExecuteComputeFunctions(run_options, argument_buffers, buffers, - hlo_execution_profile); -} - -Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { // Allocate profiling counters for each hlo instruction that we would like to // profile. std::vector* profile_counters = nullptr; @@ -428,8 +415,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( // just copy the existing buffer into the map containing instruction // results.. if (instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie(&results, instruction, - arguments[instruction->parameter_number()].opaque()); + InsertOrDie( + &results, instruction, + arguments[instruction->parameter_number()]->root_buffer().opaque()); } else if (instruction->opcode() == HloOpcode::kConstant) { unsigned char* aligned_data = FindOrDie(aligned_constants_, instruction).get(); @@ -461,69 +449,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( return Status::OK(); } -StatusOr -ParallelCpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - VLOG(3) << "ExecuteOnStream arg size: " << arguments.size(); - if (!arguments.empty()) { - VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque(); - } - - // Allocate the temporary buffers required for the computation. - se::StreamExecutor* stream_executor = stream->parent(); - int device_ordinal = stream_executor->device_ordinal(); - int64 buffer_count = assignment_->Allocations().size(); - VLOG(3) << "temp buffer count: " << buffer_count; - - std::vector device_allocations( - assignment_->Allocations().size()); - TF_RETURN_IF_ERROR(AllocateBuffers(memory_allocator, - stream->parent()->device_ordinal(), - &device_allocations)); - - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - const BufferAllocation::Index result_index = result_slice.index(); - VLOG(3) << "result index: " << result_index; - - TF_RETURN_IF_ERROR(ExecuteComputeFunctions( - run_options, arguments, device_allocations, hlo_execution_profile)); - - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. - std::unordered_set marked_addresses; - MarkLiveAddressesInOutput(device_allocations[result_index].opaque(), - result_shape(), &marked_addresses); - - VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" - << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); - - // Computation is done - deallocate temp buffers. Keep those marked - // live because they are referenced by the output of the computation - // and are needed by the service. They will be deallocated by the - // service. - for (size_t i = 0; i < device_allocations.size(); ++i) { - auto alloc = device_allocations[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && - alloc.opaque() != nullptr) { - VLOG(3) << "ParallelCpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate(device_ordinal, &alloc)); - } - } - - return device_allocations[result_index]; -} - StatusOr> ParallelCpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -536,9 +461,9 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - auto result_buffer = - MakeUnique(result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal()); + auto result_buffer = MakeUnique( + /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), + stream->parent()->platform(), stream->parent()->device_ordinal()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); @@ -549,37 +474,30 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( // Copy DeviceMemoryBase values which into the respective location in // ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffers, &buffers_in_result, &result_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = - this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *buffer_entry = result_buffer->mutable_buffers()->size(); - result_buffer->mutable_buffers()->push_back(buffer); - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); + TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + + // The points to set is unambiguous so the set should be a singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer such as a + // tuple element. The source instruction should have a non-parameter + // buffer assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *device_memory = buffer; + buffers_in_result[buffer_index] = true; + return Status::OK(); + })); // Free all buffers not in the result. for (size_t i = 0; i < buffers.size(); ++i) { @@ -595,10 +513,10 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( return std::move(result_buffer); } -StatusOr +StatusOr> ParallelCpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on CPU."); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index d65e3f42f3cb34eff005f34b51b81fd5c42974a3..90ac94ef9288b2e860cb30c47ed44a7b96e4825d 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -59,21 +59,14 @@ class ParallelCpuExecutable : public Executable { std::unique_ptr hlo_profile_index_map); ~ParallelCpuExecutable() override {} - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -108,13 +101,6 @@ class ParallelCpuExecutable : public Executable { // Calls the generated functions in 'function_names_', performing the // computation with the given arguments using the supplied buffers. - Status ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunctions( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e439cde11cf74272101b80c867a308e51ab26a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { +namespace cpu { + +ParallelLoopEmitter::ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder) + : LoopEmitter(target_element_generator, target_array, ir_builder), + dynamic_loop_bounds_(dynamic_loop_bounds) {} + +llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) { + CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsScalar(shape_)); + + llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); + const int64 num_dims = shape_.dimensions_size(); + llvm_ir::IrArray::Index array_index(num_dims); + + // Add loops from outer-most to inner-most dimensions. + for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) { + const int64 dimension = LayoutUtil::Minor(shape_.layout(), i); + const int bounds_index = num_dims - 1 - i; + if (bounds_index < dynamic_loop_bounds_->size()) { + // Emit dynamic loop bounds for this dimension. Dynamic loop bounds + // are read from ir function dynamic loop bounds argument. + llvm::Value* start_index = (*dynamic_loop_bounds_)[bounds_index].first; + llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; + + std::unique_ptr loop = loop_nest.AddLoop( + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), + start_index, end_index); + array_index[dimension] = loop->GetIndVarValue(); + } else { + // Emit static loop bounds for this dimension. + std::unique_ptr loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/shape_.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + } + // Point IR builder at inner loop BB. + llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), + ir_builder_); + + // Set exit_bb_ to the exit block of the loop nest. + exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); + CHECK(exit_bb_ != nullptr); + + return array_index; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..9335d2818e99eb3588537d80dabddda08c1c020e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" + +namespace xla { +namespace cpu { + +// ParallelLoopEmitter emits a loop nest for the target array shape. +// The outer loop bounds of the loop nest are passed as ir values at runtime +// (specified in 'dynamic_loop_bounds'), and the inner loop bounds are static. +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// Code emitted by ParallelLoopEmitter will be called in a multi-threaded +// context where each thread will be assigned a different set of outer dimension +// partitions, and where all threads will collectively iterate over the +// entire target array shape. +// +// Outer dimension partitions can be generated using the ShapePartitionAssigner +// and ShapePartitionIterator utility classes from shape_partition.cc. +// +class ParallelLoopEmitter : public llvm_ir::LoopEmitter { + public: + // Constructs a ParallelLoopEmitter which uses 'target_element_generator' to + // generate elements, 'dynamic_loop_bounds' to set the loop bounds of the + // most-major dimensions, and 'target_array.' shape to set the static loop + // bounds for the most-minor dimensions. + ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, + llvm::IRBuilder<>* ir_builder); + + ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; + ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; + ~ParallelLoopEmitter() override = default; + + llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) override; + + private: + const DynamicLoopBounds* dynamic_loop_bounds_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4b44ac8941e222d5954121bbb9654062e41f55d6..deb21bf4ef5895cfdbec5c2449b6ce7b306a7008 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -126,7 +126,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( HloInstruction* instruction) { // Currently, we do not assign parallel tasks to instructions with at least // one of the following properties: - // *) Internal threading (library calls to kConv, kDot, and kCustomCall). + // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. @@ -137,6 +137,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( instruction->opcode() == HloOpcode::kSelectAndScatter || instruction->opcode() == HloOpcode::kGetTupleElement || instruction->opcode() == HloOpcode::kBitcast || + instruction->opcode() == HloOpcode::kFft || (instruction->opcode() == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) || PotentiallyImplementedAsEigenDot(*instruction) || diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc new file mode 100644 index 0000000000000000000000000000000000000000..848d2d22414e8fc9bca82de90f7676011d8992fd --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" + +#define EIGEN_USE_THREADS + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + tensorflow::xla::EigenFftImpl(*run_options->intra_op_thread_pool(), out, + operand, fft_type, fft_rank, input_batch, + fft_length0, fft_length1, fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_fft.h new file mode 100644 index 0000000000000000000000000000000000000000..f20c5aa0aa2dcbc700f47c718e75baae18650d1a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..984cb0616e02475babad7160d0f43bb23de0b50e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -0,0 +1,240 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_ + +#include + +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/types.h" + +// 'tensorflow' namespace is used so that int64 and other types don't require +// qualification. +namespace tensorflow { +namespace xla { + +namespace internal { + +// Computes either a forward or reverse complex-to-complex FFT. +template +void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + // Create the axes (which are always trailing). + const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); + constexpr auto direction = Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE; + + const std::array fft_shape = { + {fft_length0, fft_length1, fft_length2}}; + + Eigen::DSizes dims; + dims[0] = input_batch; + for (int i = 0; i < FFTRank; i++) { + dims[i + 1] = fft_shape[i]; + } + const Eigen::TensorMap, + Eigen::Aligned> + input(operand, dims); + Eigen::TensorMap, + Eigen::Aligned> + output(out, dims); + output.device(device) = input.template fft(axes); +} + +// Computes a forward real->complex FFT, slicing out redundant negative +// frequencies from the innermost dimension. +template +void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + const std::array fft_shape = { + {fft_length0, fft_length1, fft_length2}}; + + Eigen::DSizes in_dims; + in_dims[0] = input_batch; + Eigen::DSizes out_dims; + out_dims[0] = input_batch; + TensorShape temp_shape{input_batch}; + for (int i = 0; i < FFTRank; i++) { + in_dims[i + 1] = fft_shape[i]; + out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; + temp_shape.AddDim(fft_shape[i]); + } + const Eigen::TensorMap, + Eigen::Aligned> + input(operand, in_dims); + Eigen::TensorMap, + Eigen::Aligned> + output(out, out_dims); + + // Create the axes (which are always trailing). + const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); + + // Compute the full FFT using a temporary tensor. + Tensor temp(DataTypeToEnum::v(), temp_shape); + auto full_fft = temp.flat_inner_dims(); + const Eigen::DSizes zero_start_indices; + full_fft.device(device) = + input.template fft(axes); + + // Slice away the negative frequency components. + output.device(device) = full_fft.slice(zero_start_indices, out_dims); +} + +// Computes a reverse complex->real FFT, reconstructing redundant negative +// frequencies using reverse conjugate on innermost dimension after doing IFFT +// on outer dimensions. +template +void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + const std::array fft_shape = { + {fft_length0, fft_length1, fft_length2}}; + + Eigen::DSizes in_dims; + in_dims[0] = input_batch; + Eigen::DSizes out_dims; + out_dims[0] = input_batch; + TensorShape temp_shape{input_batch}; + for (int i = 0; i < FFTRank; i++) { + in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; + out_dims[i + 1] = fft_shape[i]; + temp_shape.AddDim(fft_shape[i]); + } + const Eigen::TensorMap, + Eigen::Aligned> + input(operand, in_dims); + Eigen::TensorMap, + Eigen::Aligned> + output(out, out_dims); + + // Calculate the shape of the temporary tensor for the full FFT and the + // region we will slice from input given fft_shape. We slice input to + // fft_shape on its inner-most dimensions, except the last (which we + // slice to fft_shape[-1] / 2 + 1). + Tensor temp(DataTypeToEnum::v(), temp_shape); + auto full_fft = temp.flat_inner_dims(); + + // Calculate the starting point and range of the source of + // negative frequency part. + auto neg_sizes = in_dims; + neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank]; + Eigen::DSizes neg_target_indices; + neg_target_indices[FFTRank] = in_dims[FFTRank]; + + const Eigen::DSizes zero_start_indices; + Eigen::DSizes neg_start_indices; + neg_start_indices[FFTRank] = 1; + + full_fft.slice(zero_start_indices, in_dims).device(device) = input; + + // First, conduct IFFTs on outer dimensions. We save computation (and + // avoid touching uninitialized memory) by slicing full_fft to the + // subregion we wrote input to. + if (FFTRank > 1) { + const auto outer_axes = + Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1); + full_fft.slice(zero_start_indices, in_dims).device(device) = + full_fft.slice(zero_start_indices, in_dims) + .template fft(outer_axes); + } + + // Reconstruct the full FFT by appending reversed and conjugated + // spectrum as the negative frequency part. + Eigen::array reverse_last_axis; + for (auto i = 0; i <= FFTRank; i++) { + reverse_last_axis[i] = i == FFTRank; + } + + if (neg_sizes[FFTRank] != 0) { + full_fft.slice(neg_target_indices, neg_sizes).device(device) = + full_fft.slice(neg_start_indices, neg_sizes) + .reverse(reverse_last_axis) + .conjugate(); + } + + auto inner_axis = Eigen::array{FFTRank}; + output.device(device) = + full_fft.template fft(inner_axis); +} + +template +void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, + int32 fft_type, int64 input_batch, int64 fft_length0, + int64 fft_length1, int64 fft_length2) { + CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; + switch (fft_type) { + case ::xla::FftType::FFT: + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + break; + case ::xla::FftType::IFFT: + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + break; + case ::xla::FftType::RFFT: + EigenFftR2C( + device, static_cast(out), static_cast(operand), + input_batch, fft_length0, fft_length1, fft_length2); + break; + case ::xla::FftType::IRFFT: + EigenFftC2R( + device, static_cast(out), static_cast(operand), + input_batch, fft_length0, fft_length1, fft_length2); + break; + default: + LOG(FATAL) << "Unsupported FFT type: " << fft_type; + } +} + +} // namespace internal + +template +void EigenFftImpl(const EigenDevice& device, void* out, void* operand, + int32 fft_type, int32 fft_rank, int64 input_batch, + int64 fft_length0, int64 fft_length1, int64 fft_length2) { + switch (fft_rank) { + case 1: + internal::EigenFftWithRank<1, EigenDevice>( + device, out, operand, fft_type, input_batch, fft_length0, 0, 0); + break; + case 2: + internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type, + input_batch, fft_length0, + fft_length1, 0); + break; + case 3: + internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type, + input_batch, fft_length0, + fft_length1, fft_length2); + break; + default: + LOG(FATAL) << "Unsupported FFT rank " << fft_rank; + } +} + +} // namespace xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index cda2783307925b77ac6d8cfe679c5b325db2befc..5403bf48b748c587802c6ed7abb4699e8395ca67 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" -#include #include #include #include @@ -34,10 +33,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -102,9 +103,21 @@ llvm::StringRef GetHostCpuName() { CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { CompilerFunctor::VectorIntrinsics intrinsics; - intrinsics.sse_intrinsics = (&__xla_cpu_runtime_ExpV4F32SSE != nullptr); - intrinsics.avx_intrinsics = (&__xla_cpu_runtime_ExpV8F32AVX != nullptr); - intrinsics.neon_intrinsics = (&__xla_cpu_runtime_ExpV4F32NEON != nullptr); +#ifdef TF_XLA_HAS_SSE4_1 + intrinsics.sse_intrinsics = true; +#else + intrinsics.sse_intrinsics = false; +#endif +#ifdef TF_XLA_HAS_AVX + intrinsics.avx_intrinsics = true; +#else + intrinsics.avx_intrinsics = false; +#endif +#ifdef TF_XLA_HAS_NEON + intrinsics.neon_intrinsics = true; +#else + intrinsics.neon_intrinsics = false; +#endif return intrinsics; } @@ -196,17 +209,24 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); +#ifdef TF_XLA_HAS_NEON REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON); - REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); - REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON); +#endif +#ifdef TF_XLA_HAS_SSE4_1 + REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE); +#endif +#ifdef TF_XLA_HAS_AVX + REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX); +#endif REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); @@ -275,7 +295,11 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(scalbln, double (*)(double, long)); REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); REGISTER_LIBM_SYMBOL(sin, double (*)(double)); +#ifdef __APPLE__ + REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); +#else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); +#endif REGISTER_LIBM_SYMBOL(sinh, double (*)(double)); REGISTER_LIBM_SYMBOL(sqrt, double (*)(double)); REGISTER_LIBM_SYMBOL(tan, double (*)(double)); diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc new file mode 100644 index 0000000000000000000000000000000000000000..eeb049737dddd11ef2ce229df772baec3ac03dd8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" + +namespace xla { +namespace cpu { + +llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( + const llvm::Function& function) const { + auto it = target_transform_info_cache_.find(&function); + if (it == target_transform_info_cache_.end()) { + auto emplace_result = target_transform_info_cache_.emplace( + &function, target_machine_->getTargetTransformInfo(function)); + CHECK(emplace_result.second); + it = emplace_result.first; + } + + return &it->second; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h new file mode 100644 index 0000000000000000000000000000000000000000..703942615e552dccde7ddec8c8b90e8a486652af --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ + +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { +namespace cpu { + +// Wraps an llvm::TargetMachine and parses out some information that feeds into +// LLVM IR code generation decisions. +class TargetMachineFeatures { + public: + static constexpr int kX86AvxVectorByteSize = 32; + + TargetMachineFeatures(llvm::TargetMachine* target_machine) + : target_machine_(target_machine) {} + + // Return the vectorization factor, which is the number of bytes of data + // explicitly vectorized routines will try to process at once. + int vectorization_factor_in_bytes() const { + // Ideally this should be a function of the cache line size (which we can + // get from llvm::TargetTransformInfo::getCacheLineSize) of the target + // machine. Guess a value of 128 bytes for now. + return 128; + } + + // Return the size of the largest vector size in bytes. We need to pass in + // "function" since llvm functions can contain annotations for specializing + // them to specific micro-architectures (though currently XLA does not use + // this functionality). + int vector_register_byte_size(const llvm::Function& function) const { + llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); + return tti->getRegisterBitWidth(/*Vector=*/true) / 8; + } + + // Return the number of elements of type `type` that can fit into the largest + // vector register available. We need to pass in "function" since llvm + // functions can contain annotations for specializing them to specific + // micro-architectures (though currently XLA does not use this functionality). + int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const { + return vector_register_byte_size(function) / + (primitive_util::BitWidth(type) / 8); + } + + private: + llvm::TargetTransformInfo* GetTargetTransformInfoFor( + const llvm::Function& function) const; + + // This cache saves us from having to create a llvm::TargetTransformInfo for + // every call to GetTargetTransformInfoFor (creating a TargetTransformInfo + // costs one heap allocation on X86). + // + // This is mutated from within `GetTargetTransformInfoFor` which is + // semantically a getter (and thus `const`); and is therefore declared + // mutable. Making this mutable is okay because it has cache semantics. + mutable tensorflow::gtl::FlatMap + target_transform_info_cache_; + llvm::TargetMachine* target_machine_; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc similarity index 51% rename from tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc rename to tensorflow/compiler/xla/service/cpu/vector_support_library.cc index e8c6a83618eaa8430521197f1c166cb7eb11a28e..128b465be239130918687d8e2ba0458684086ee1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -13,11 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { +namespace cpu { VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, llvm::IRBuilder<>* ir_builder, @@ -34,6 +36,12 @@ VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, } llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { + CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + return MulInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, + llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { return ir_builder()->CreateFMul(lhs, rhs, name()); } else { @@ -42,6 +50,12 @@ llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { } llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { + CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + return AddInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, + llvm::Value* rhs) { if (scalar_type_->isFloatingPointTy()) { return ir_builder()->CreateFAdd(lhs, rhs, name()); } else { @@ -129,6 +143,123 @@ llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { name()); } +llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, + llvm::Value* rhs) { + CHECK_EQ(lhs->getType(), vector_type()); + CHECK_EQ(rhs->getType(), vector_type()); + CHECK_EQ(vector_size() % 2, 0); + + llvm::SmallVector mask_a, mask_b; + + // Adding the values shuffled using mask_a and mask_b gives us the + // AVX-style horizontal add we want. The masks work as documented + // in https://llvm.org/docs/LangRef.html#shufflevector-instruction + // + // Here are the masks for vector_width() == 8: + // + // index: |0 |1 |2 | 3 |4 |5 | 6 | 7 + // --------+--+--+--+---+--+--+---+--- + // mask_a: |0 |2 |8 |10 |4 |6 |12 |14 + // mask_b: |1 |3 |9 |11 |5 |7 |13 |16 + // + // So, as an example, the value at lane 3 of the result vector is + // the result of adding lane 10 and lane 11 in the combined lhs++rhs + // vector, which are the lanes 2 and 3 in the rhs vector. + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size(); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + + llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_a)); + llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_b)); + + return Add(shuffle_0, shuffle_1); +} + +llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i + vector_size() / 2)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +std::vector VectorSupportLibrary::ComputeHorizontalSums( + std::vector vectors, llvm::Value* init_values) { + const int x86_avx_vector_elements = + TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size(); + if (vector_size() == x86_avx_vector_elements && + vectors.size() == x86_avx_vector_elements) { + return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values); + } + + std::vector result; + std::transform(vectors.begin(), vectors.end(), std::back_inserter(result), + [this](llvm::Value* vector) { return AddReduce(vector); }); + if (init_values) { + for (int64 i = 0, e = result.size(); i < e; i++) { + result[i] = Add(result[i], ir_builder()->CreateExtractElement( + init_values, ir_builder()->getInt32(i))); + } + } + return result; +} + +std::vector +VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( + std::vector vectors, llvm::Value* init_values) { + while (vectors.size() != 2) { + std::vector new_vectors; + for (int i = 0; i < vectors.size(); i += 2) { + new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1])); + } + + vectors = std::move(new_vectors); + } + + llvm::Value* low = + AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0])); + if (init_values) { + low = AddInternal(ExtractLowHalf(init_values), low); + } + llvm::Value* high = + AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1])); + if (init_values) { + high = AddInternal(ExtractHighHalf(init_values), high); + } + + std::vector results; + for (int i = 0; i < 8; i++) { + llvm::Value* scalar_result = ir_builder()->CreateExtractElement( + i < 4 ? low : high, ir_builder()->getInt32(i % 4), name()); + results.push_back(scalar_result); + } + + return results; +} + llvm::Value* VectorSupportLibrary::GetZeroVector() { return llvm::Constant::getNullValue(vector_type()); } @@ -142,9 +273,12 @@ LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); } -llvm::Value* LlvmVariable::Get() { return ir_builder_->CreateLoad(alloca_); } +llvm::Value* LlvmVariable::Get() const { + return ir_builder_->CreateLoad(alloca_); +} void LlvmVariable::Set(llvm::Value* new_value) { ir_builder_->CreateStore(new_value, alloca_); } +} // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h similarity index 76% rename from tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h rename to tensorflow/compiler/xla/service/cpu/vector_support_library.h index 3072677ab05aa91c736baaa0dc3023329d810a52..8fbac2a6670f8ef18c00877a1566bd4ab896a7c8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -13,17 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_ #include #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { +namespace cpu { // A thin wrapper around llvm_util.h to make code generating vector math flow // more readable. class VectorSupportLibrary { @@ -111,7 +113,12 @@ class VectorSupportLibrary { return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements)); } - llvm::Value* AddReduce(llvm::Value* vector); + // Compute the horizontal sum of each vector in `vectors`. The i'th element + // in the result vector is the (scalar) horizontal sum of the i'th vector in + // `vectors`. If `init_values` is not nullptr then the value in the i'th lane + // in `init_values` is added to the i'th horizontal sum. + std::vector ComputeHorizontalSums( + std::vector vectors, llvm::Value* init_values = nullptr); llvm::Value* GetZeroVector(); llvm::Value* GetZeroScalar(); @@ -122,10 +129,40 @@ class VectorSupportLibrary { llvm::Type* vector_pointer_type() const { return vector_pointer_type_; } llvm::Type* scalar_type() const { return scalar_type_; } llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; } + int64 scalar_byte_size() const { + return primitive_util::BitWidth(primitive_type_) / 8; + } const std::string& name() const { return name_; } private: + llvm::Value* ExtractLowHalf(llvm::Value*); + llvm::Value* ExtractHighHalf(llvm::Value*); + + llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs); + + llvm::Value* AddReduce(llvm::Value* vector); + + // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The + // resulting IR for an 8-float wide vector is expected to lower to a single + // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in + // other cases. + // + // For a vector width of 8, the result vector is computed as: + // Result[0] = Lhs[0] + Lhs[1] + // Result[1] = Lhs[2] + Lhs[3] + // Result[2] = Rhs[0] + Rhs[1] + // Result[3] = Rhs[2] + Rhs[3] + // Result[4] = Lhs[4] + Lhs[5] + // Result[5] = Lhs[6] + Lhs[7] + // Result[6] = Rhs[4] + Rhs[5] + // Result[7] = Rhs[6] + Rhs[7] + llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs); + + std::vector ComputeAvxOptimizedHorizontalSums( + std::vector vectors, llvm::Value* init_values); + int64 vector_size_; PrimitiveType primitive_type_; llvm::IRBuilder<>* ir_builder_; @@ -142,7 +179,7 @@ class LlvmVariable { public: LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder); - llvm::Value* Get(); + llvm::Value* Get() const; void Set(llvm::Value* new_value); private: @@ -169,6 +206,7 @@ class ScalarVariable : public LlvmVariable { Set(initial_value); } }; +} // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/windows_compatibility.cc b/tensorflow/compiler/xla/service/cpu/windows_compatibility.cc new file mode 100644 index 0000000000000000000000000000000000000000..ab308ee6cb16ba95e24694b59a4b5737765bbb8b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/windows_compatibility.cc @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" + +#ifdef _MSC_VER + +#include + +void sincos(double x, double *sinv, double *cosv) { + *sinv = sin(x); + *cosv = cos(x); +} + +void sincosf(float x, float *sinv, float *cosv) { + *sinv = sinf(x); + *cosv = cosf(x); +} + +#endif // _MSC_VER diff --git a/tensorflow/compiler/xla/service/cpu/windows_compatibility.h b/tensorflow/compiler/xla/service/cpu/windows_compatibility.h new file mode 100644 index 0000000000000000000000000000000000000000..262f379d8b6017f4a7e0156b724bfee7e8ec5b9a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/windows_compatibility.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_WINDOWS_COMPATIBILITY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_WINDOWS_COMPATIBILITY_H_ + +#ifdef _MSC_VER + +extern "C" { + +// MSVC does not have sincos[f]. +void sincos(double x, double *sinv, double *cosv); +void sincosf(float x, float *sinv, float *cosv); + +} + +#endif // _MSC_VER + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_WINDOWS_COMPATIBILITY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc index d0f214202908266371639af8f431ad8269ad0e35..47543b2082f55cf7b8cf60f1c5bbb16a0a609912 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc @@ -41,6 +41,8 @@ void XfeedQueueManager::EnqueueBuffersAtomically( tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffers_.empty(); for (XfeedBuffer* b : buffers) { + VLOG(3) << "Enqueueing " << queue_name_ << " buffer (of " << buffers.size() + << " buffers) with length: " << b->length(); enqueued_buffers_.push_back(b); } if (was_empty && !buffers.empty()) { @@ -54,9 +56,11 @@ void XfeedQueueManager::EnqueueBuffersAtomically( XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() { tensorflow::mutex_lock l(mu_); + VLOG(3) << "Waiting for an available buffer."; while (enqueued_buffers_.empty()) { cv_.wait(l); } + VLOG(3) << "A buffer is available!"; CHECK(current_buffer_ == nullptr); current_buffer_ = enqueued_buffers_.front(); enqueued_buffers_.pop_front(); @@ -65,6 +69,9 @@ XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() { void XfeedQueueManager::ReleaseCurrentBuffer(int32 length, void* data, StatusOr shape) { + VLOG(3) << "Releasing buffer with shape: " + << (shape.ok() ? ShapeUtil::HumanString(shape.ValueOrDie()) + : ""); tensorflow::mutex_lock l(mu_); CHECK(current_buffer_ != nullptr); CHECK_EQ(length, current_buffer_->length()); diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index 6af55700052007a2ca419d52b63dddea2052bd0b..b4ace232607e14fbfec01d48946f0031d96cd027 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -50,7 +50,7 @@ class XfeedBuffer { // Reusable component for managing the infeed and outfeed queue state. class XfeedQueueManager { public: - XfeedQueueManager() = default; + XfeedQueueManager(string queue_name) : queue_name_(queue_name) {} // Calls the completion callback for any enqueued buffers that have // not been dequeued by the runtime, and empties the @@ -86,6 +86,8 @@ class XfeedQueueManager { void ReleaseCurrentBuffer(int32 length, void* data, StatusOr shape); private: + const string queue_name_; + tensorflow::mutex mu_; // Condition variable that is signaled every time a buffer is @@ -112,8 +114,8 @@ class XfeedManager { XfeedQueueManager* outfeed() { return &outfeed_; } private: - XfeedQueueManager infeed_; - XfeedQueueManager outfeed_; + XfeedQueueManager infeed_ = {"infeed"}; + XfeedQueueManager outfeed_ = {"outfeed"}; }; } // namespace runtime diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 91086fd4a5f68211ef56c2417bb0ef4a38de2cff..a803b3171f9afa6297553c5507c4f9aa45e420ab 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -103,6 +103,7 @@ class DfsHloVisitorBase { return HandleElementwiseBinary(hlo); } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; + virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); @@ -247,6 +248,10 @@ class DfsHloVisitorBase { // affecting correctness. void ReserveVisitStates(int num) { visit_state_.Reserve(num); } + // Useful when we want to visit the same computation more than once with the + // same visitor. + void ResetVisitStates() { visit_state_.Reset(); } + void SetVisitState(int id, VisitState state) { visit_state_.SetState(id, state); } @@ -326,6 +331,7 @@ class DfsHloVisitorBase { *w = (*w & ~mask) | (static_cast(state) << shift); DCHECK_EQ(GetState(id), state); } + void Reset() { states_.clear(); } private: static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 133aa2509405738de8388708b0c61a82023e2738..170adb3d241b3648bc53f96dde9866f0b794f80a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -85,6 +85,9 @@ class DfsHloVisitorWithDefaultBase Status HandleConvolution(HloInstructionPtr convolution) override { return DefaultAction(convolution); } + Status HandleFft(HloInstructionPtr fft) override { + return DefaultAction(fft); + } Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc new file mode 100644 index 0000000000000000000000000000000000000000..12faed69677cd99c6ed82c8d13dad3138d9461b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -0,0 +1,185 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dot_decomposer.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +// TODO(b/69062148) Remove this code when all backends support BatchDot +// natively. +Status DecomposeBatchDot(HloInstruction* dot) { + auto computation = dot->parent(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + const Shape& dot_shape = dot->shape(); + + // ShapeInference should guarantee that lhs/rhs batch dimensions match. + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + const int64 num_batch_dims = dnums.lhs_batch_dimensions_size(); + // Calculate total batch size (note that ShapeInference requires that + // the batch dimensions are most-major). + int64 batch_size = 1; + for (int i = 0; i < num_batch_dims; ++i) { + CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)), + rhs_shape.dimensions(dnums.rhs_batch_dimensions(i))); + batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)); + } + + // Set lhs/rhs_transpose. + CHECK_EQ(1, dnums.lhs_contracting_dimensions_size()); + const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0); + const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0; + + CHECK_EQ(1, dnums.rhs_contracting_dimensions_size()); + const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0); + const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1; + + // Compute R3 and R3 shapes for lhs. + PrimitiveType lhs_type = lhs_shape.element_type(); + const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0); + const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1); + Shape lhs_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r2 = + ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols}); + + // Compute R3 and R3 shapes for rhs. + PrimitiveType rhs_type = rhs_shape.element_type(); + const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0); + const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1); + Shape rhs_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r2 = + ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols}); + + // Compute R3 and R3 shapes for dot output. + PrimitiveType dot_type = dot_shape.element_type(); + const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0); + const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1); + Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols}); + Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols}); + Shape concat_shape_r3 = + ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols}); + + // Reshape lhs/rhs into R3. + auto lhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_shape_r3, lhs)); + auto rhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_shape_r3, rhs)); + + // Loop through batch size, slicing out required lhs/rhs to compute each Dot. + std::vector output_slices(batch_size); + for (int64 i = 0; i < batch_size; ++i) { + // Slice R3 shape from 'lhs' and reshape to R2. + auto lhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0}, + {i + 1, lhs_rows, lhs_cols}, {1, 1, 1})); + auto lhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3)); + + // Slice R3 shape from 'rhs' and reshape to R2. + auto rhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0}, + {i + 1, rhs_rows, rhs_cols}, {1, 1, 1})); + auto rhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3)); + + // Transpose lhs/rhs (if needed). + if (lhs_transpose) { + Shape lhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows}); + lhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0})); + } + if (rhs_transpose) { + Shape rhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows}); + rhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0})); + } + + // Compute Dot of lhs/rhs R2 slices. + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( + dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); + + // Reshape Dot to R3 so we can concat along batch dimension. + auto dot_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape_r3, dot_r2)); + + output_slices[i] = dot_r3; + } + + // Concatenate slices from 'output_slices' along batch dimension. + auto concat = computation->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0)); + // Reshape output 'new_dot' to original dimensions. + auto new_dot = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape, concat)); + + // Replace all uses of 'dot' in 'computation' with 'new_dot'. + return computation->ReplaceInstruction(dot, new_dot); +} + +} // namespace + +StatusOr DotDecomposer::Run(HloModule* module) { + XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString()); + // Gather all batch Dot operations. + std::vector batch_dots; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) { + batch_dots.push_back(instruction); + } + } + } + // Decompose each batch Dot in 'batch_dots'. + bool changed = false; + for (auto* dot : batch_dots) { + TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + changed = true; + } + XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h new file mode 100644 index 0000000000000000000000000000000000000000..5ff0ab34eac0cd0fbc264b408c57653c944402a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// DotDecomposer is a pass which decomposes batch Dot operations into a +// sequence of smaller (R2) Dot operations. +class DotDecomposer : public HloPassInterface { + public: + // Decomposes batch Dot operations when 'decompose_batch_dot' is true. + DotDecomposer(bool decompose_batch_dot = true) + : decompose_batch_dot_(decompose_batch_dot) {} + ~DotDecomposer() = default; + tensorflow::StringPiece name() const override { return "dot_decomposer"; } + + // Run DotDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. + StatusOr Run(HloModule* module) override; + + private: + bool decompose_batch_dot_; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index b9407818cd8bc82aabd32ed02f61ef66fe442625..9780bac16ec17eed2c1df64f01bcb753e26b46f0 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -50,11 +50,161 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrCat; +namespace { + +llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, + int64 mantissa_bits, + llvm::IRBuilder<>* ir_builder) { + // Integer and float types for casting and constant generation. + llvm::Type* float_type = x->getType(); + llvm::IntegerType* int_type = ir_builder->getInt32Ty(); + + // Cast the input value to an integer for bitwise manipulation. + llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type); + + if (mantissa_bits < 23) { + // Last remaining mantissa bit. + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. + const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; + llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr( + ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + (23 - mantissa_bits)); + llvm::Value* x_rounding_bias = ir_builder->CreateAdd( + x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); + + // Add rounding bias, and mask out truncated bits. Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + } + + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + llvm::Value* x_exponent = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + llvm::Value* x_overflows = ir_builder->CreateICmpUGT( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); + llvm::Value* x_underflows = ir_builder->CreateICmpULE( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); + + // Compute appropriately-signed values of zero and infinity. + llvm::Value* x_signed_zero = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); + llvm::Value* x_signed_inf = ir_builder->CreateOr( + x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int); + } + + // Cast the result back to a floating-point type. + llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type); + + // Correct result for NaN inputs. + // + // The exponent handling will "normalize" NaN values to infinities, which is + // undesirable (except in the case with no mantissa bits, in which case it + // is mandatory). This logic also handles cases where mantissa-rounding + // causes a NaN's mantissa to overflow into the exponent bits, which would + // otherwise create an erroneous zero value. + // + // If the fast-math flags are set to assume no NaNs, the comparison is likely + // to be optimized away, so there's no point in even emitting it. + if (!ir_builder->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x); + + if (mantissa_bits > 0) { + result = ir_builder->CreateSelect(x_is_nan, x, result); + } else { + result = ir_builder->CreateSelect( + x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); + } + } + return result; +} + +llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, + llvm::IRBuilder<>* ir_builder) { + auto reduced_precision = EmitReducePrecisionFloat( + f32_value, + /*exponent_bits=*/primitive_util::kBFloat16ExponentBits, + /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder); + auto as_int32 = + ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateLShr(as_int32, 16); + auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty()); + return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty()); +} + +llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, + llvm::IRBuilder<>* ir_builder) { + auto as_int16 = + ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty()); + auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateShl(as_int32, 16); + return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy()); +} + +llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, + PrimitiveType from_type, + PrimitiveType to_type, llvm::Module* module, + llvm::IRBuilder<>* ir_builder) { + if (primitive_util::IsSignedIntegralType(from_type)) { + return ir_builder->CreateSIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } else { + CHECK(primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED); + return ir_builder->CreateUIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } +} + +} // namespace + StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { if (op->opcode() == HloOpcode::kCopy) { return operand_value; - } else if (operand_value->getType()->isIntegerTy()) { + } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + op->operand(0)->shape().element_type() == PRED) { return EmitIntegerUnaryOp(op, operand_value); } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { return EmitComplexUnaryOp(op, operand_value); @@ -79,15 +229,14 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::IsSignedIntegralType(to_type)); } if (primitive_util::IsFloatingPointType(to_type)) { - if (primitive_util::IsSignedIntegralType(from_type)) { - return ir_builder_->CreateSIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); - } - if (primitive_util::IsUnsignedIntegralType(from_type) || - from_type == PRED) { - return ir_builder_->CreateUIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + if (to_type == BF16) { + return EmitF32ToBF16( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + ir_builder_), + ir_builder_); } + return EmitIntegralToFloating(operand_value, from_type, to_type, + module_, ir_builder_); } if (primitive_util::IsComplexType(to_type)) { auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( @@ -207,6 +356,17 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } + if (from_type == BF16) { + TF_RET_CHECK(to_type != BF16); + operand_value = EmitBF16ToF32(operand_value, ir_builder_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } + if (from_type == F32 && to_type == BF16) { + return EmitF32ToBF16(operand_value, ir_builder_); + } if (primitive_util::IsFloatingPointType(to_type)) { return ir_builder_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); @@ -244,21 +404,13 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( primitive_util::BitWidth(to_type)); } case HloOpcode::kExp: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitExp(op->shape().element_type(), operand_value); case HloOpcode::kLog: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitLog(op->shape().element_type(), operand_value); case HloOpcode::kCos: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitSin(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, @@ -309,9 +461,25 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( StatusOr ElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + PrimitiveType component_type = + primitive_util::IsComplexType(input_type) + ? primitive_util::ComplexComponentType(input_type) + : input_type; switch (op->opcode()) { - // TODO(b/65209142): Angle/Log require atan2. - // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + case HloOpcode::kLog: { + // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); + TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); TF_RET_CHECK(primitive_util::IsComplexType(from_type)); @@ -333,15 +501,12 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto exp_a = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::exp, {EmitExtractReal(operand_value)}, - {EmitExtractReal(operand_value)->getType()}, ir_builder_); - auto cos_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::cos, {EmitExtractImag(operand_value)}, - {EmitExtractImag(operand_value)->getType()}, ir_builder_); - auto sin_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sin, {EmitExtractImag(operand_value)}, - {EmitExtractImag(operand_value)->getType()}, ir_builder_); + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value))); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), ir_builder_->CreateFMul(exp_a, sin_b)); } @@ -356,16 +521,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); auto type = a->getType(); - auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); auto half_exp_b = ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); - auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, - {type}, ir_builder_); - auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); + TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( op, ir_builder_->CreateFMul( @@ -386,16 +548,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); auto type = a->getType(); - auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); auto half_exp_b = ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); - auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, - {type}, ir_builder_); - auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); + TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( op, ir_builder_->CreateFMul( @@ -403,6 +562,58 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); } + case HloOpcode::kTanh: { + /* + tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) + e^(a+bi) = e^a*(cos(b)+sin(b)i) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) + cos(b)=cos(-b), sin(-b)=-sin(b) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) + =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / + (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / + (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) + This is a complex division, so we can multiply by denom_conj/denom_conj + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * + (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + + i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + */ + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); + TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); + TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); + auto exp_neg_a = ir_builder_->CreateFDiv( + llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(exp_a, exp_a), + ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); + auto real_num = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_a_plus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + auto exp_a_minus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = ir_builder_->CreateFMul( + cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, + exp_a_minus_exp_neg_a_sq)); + auto denom = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), + ir_builder_->CreateFDiv(imag_num, denom)); + } case HloOpcode::kAbs: { auto sum_sq = ir_builder_->CreateFAdd( ir_builder_->CreateFMul(EmitExtractReal(operand_value), @@ -449,7 +660,8 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { PrimitiveType operand_type = op->operand(0)->shape().element_type(); - if (lhs_value->getType()->isIntegerTy()) { + if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + operand_type == PRED) { return EmitIntegerBinaryOp( op, lhs_value, rhs_value, primitive_util::IsSignedIntegralType(operand_type)); @@ -464,7 +676,6 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { switch (op->opcode()) { - // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: @@ -508,10 +719,9 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( case HloOpcode::kMinimum: return EmitFloatMin(lhs_value, rhs_value); case HloOpcode::kPower: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, - {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder_); - + return EmitPow(op->shape().element_type(), lhs_value, rhs_value); + case HloOpcode::kAtan2: + return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value); default: return Unimplemented("binary floating point op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -607,9 +817,40 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( EmitExtractImag(lhs_value), EmitExtractImag(rhs_value), ir_builder_)); - // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic - // case HloOpcode::kPower: - // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2) + case HloOpcode::kPower: { + // (a+bi)^(c+di) = + // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), + // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) + PrimitiveType component_type = + primitive_util::ComplexComponentType(op->shape().element_type()); + auto a = EmitExtractReal(lhs_value); + auto b = EmitExtractImag(lhs_value); + auto c = EmitExtractReal(rhs_value); + auto d = EmitExtractImag(rhs_value); + auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto half_c = ir_builder_->CreateFMul(one_half, c); + + TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, + EmitPow(component_type, aa_p_bb, half_c)); + auto neg_d = ir_builder_->CreateFNeg(d); + TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); + auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, + EmitExp(component_type, neg_d_arg_lhs)); + auto coeff = + ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); + auto half_d = ir_builder_->CreateFMul(one_half, d); + auto q = + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), + ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); + TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), + ir_builder_->CreateFMul(coeff, sin_q)); + } default: return Unimplemented("binary complex op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -712,116 +953,51 @@ StatusOr ElementalIrEmitter::EmitErfcInv( return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); } -StatusOr ElementalIrEmitter::EmitReducePrecision( - const HloInstruction* hlo, llvm::Value* x) const { - if (hlo->operand(0)->shape().element_type() != F32) { - return Unimplemented("reduce-precision only implemented for F32"); - } - - // Integer and float types for casting and constant generation. - llvm::Type* float_type = x->getType(); - llvm::IntegerType* int_type = ir_builder_->getInt32Ty(); - - // Cast the input value to an integer for bitwise manipulation. - llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type); - - if (hlo->mantissa_bits() < 23) { - // Last remaining mantissa bit. - const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits()); - - // Compute rounding bias for round-to-nearest with ties to even. This is - // equal to a base value of 0111... plus one bit if the last remaining - // mantissa bit is 1. - const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; - llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr( - ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), - (23 - hlo->mantissa_bits())); - llvm::Value* x_rounding_bias = ir_builder_->CreateAdd( - x_last_mantissa_bit, - llvm::ConstantInt::get(int_type, base_rounding_bias)); - - // Add rounding bias, and mask out truncated bits. Note that the case - // where adding the rounding bias overflows into the exponent bits is - // correct; the non-masked mantissa bits will all be zero, and the - // exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias); - x_as_int = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); - } - - if (hlo->exponent_bits() < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- - // significant bit -- is equal to 1.0f for all exponent sizes. Adding - // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- - // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' - // exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n is - // (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (hlo->exponent_bits() - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; +StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, + {value->getType()}, ir_builder_); +} - // Do we overflow or underflow? - llvm::Value* x_exponent = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - llvm::Value* x_overflows = ir_builder_->CreateICmpUGT( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); - llvm::Value* x_underflows = ir_builder_->CreateICmpULE( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); +StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, + {value->getType()}, ir_builder_); +} - // Compute appropriately-signed values of zero and infinity. - llvm::Value* x_signed_zero = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); - llvm::Value* x_signed_inf = ir_builder_->CreateOr( - x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); +StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, + {value->getType()}, ir_builder_); +} - // Force to zero or infinity if overflow or underflow. (Note that this - // truncates all denormal values to zero, rather than rounding them.) - x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int); - x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int); - } +StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, + {value->getType()}, ir_builder_); +} - // Cast the result back to a floating-point type. - llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type); +StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, + {lhs->getType()}, ir_builder_); +} - // Correct result for NaN inputs. - // - // The exponent handling will "normalize" NaN values to infinities, which is - // undesirable (except in the case with no mantissa bits, in which case it - // is mandatory). This logic also handles cases where mantissa-rounding - // causes a NaN's mantissa to overflow into the exponent bits, which would - // otherwise create an erroneous zero value. - // - // If the fast-math flags are set to assume no NaNs, the comparison is likely - // to be optimized away, so there's no point in even emitting it. - if (!ir_builder_->getFastMathFlags().noNaNs()) { - llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x); +StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return Unimplemented("atan2"); +} - if (hlo->mantissa_bits() > 0) { - result = ir_builder_->CreateSelect(x_is_nan, x, result); - } else { - result = ir_builder_->CreateSelect( - x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); - } +StatusOr ElementalIrEmitter::EmitReducePrecision( + const HloInstruction* hlo, llvm::Value* x) const { + if (hlo->operand(0)->shape().element_type() != F32) { + return Unimplemented("reduce-precision only implemented for F32"); } - return result; + return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(), + /*mantissa_bits=*/hlo->mantissa_bits(), + ir_builder_); } StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( @@ -1088,14 +1264,6 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( get_next_uniform_float()))); return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m); } - case RNG_BERNOULLI: { - TF_ASSIGN_OR_RETURN(llvm::Value * p, - operand_to_generator.at(hlo->operand(0))(index)); - return ir_builder_->CreateZExt( - ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_)); - } default: return InvalidArgument( "unhandled distribution %s", diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index cccb498f82936283a215370787907b293827ff2d..1a48eb5fcb960b60d524ea56a43e15269576db76 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -39,7 +39,7 @@ class ElementalIrEmitter { module_(module), hlo_module_config_(hlo_module_config) {} - virtual ~ElementalIrEmitter() {} + virtual ~ElementalIrEmitter() = default; virtual StatusOr EmitUnaryOp(const HloInstruction* op, llvm::Value* operand_value) const; @@ -92,6 +92,26 @@ class ElementalIrEmitter { virtual StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + + virtual StatusOr EmitLog(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitSin(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitCos(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitExp(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 9c96d9eb30b5f9e51b7f5d82391c6b9f366898d6..21e7fbea291721dfc446bae2a7002a8ec2520be4 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -24,25 +24,25 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" +using tensorflow::gtl::ArraySlice; + namespace xla { -StatusOr> +StatusOr>> Executable::ExecuteOnStreams( - tensorflow::gtl::ArraySlice run_options, - tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> - arguments) { + ArraySlice run_options, + ArraySlice> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); + std::vector> return_values(run_options.size()); + if (run_options.size() == 1) { - TF_ASSIGN_OR_RETURN(auto result, + TF_ASSIGN_OR_RETURN(return_values[0], ExecuteOnStream(&run_options[0], arguments[0], /*hlo_execution_profile=*/nullptr)); - return std::vector({result}); + return std::move(return_values); } - std::vector return_values( - run_options.size()); for (size_t i = 0; i < run_options.size(); ++i) { // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched @@ -52,9 +52,77 @@ Executable::ExecuteOnStreams( } for (const auto& options : run_options) { TF_RET_CHECK(options.stream() != nullptr); - options.stream()->BlockHostUntilDone(); + TF_RETURN_IF_ERROR(options.stream()->BlockHostUntilDone()); + } + return std::move(return_values); +} + +StatusOr> Executable::ExecuteOnStreamWrapper( + const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, + ArraySlice arguments) { + perftools::gputools::Stream* stream = run_options->stream(); + std::unique_ptr timer; + if (profile != nullptr) { + timer.reset(new perftools::gputools::Timer(stream->parent())); + stream->InitTimer(timer.get()).ThenStartTimer(timer.get()); } - return return_values; + + VLOG(1) << "enqueueing executable on stream..."; + // If the profiling flag isn't enabled, we pass nullptr as the profile to + // indicate profiling is not requested. + std::unique_ptr profile_ptr = + module_config().debug_options().xla_hlo_profile() && + hlo_profiling_enabled() + ? MakeUnique(&hlo_profile_printer(), + &hlo_profile_index_map()) + : nullptr; + + StatusOr> return_value = + ExecuteOnStream(run_options, arguments, profile_ptr.get()); + + if (profile != nullptr) { + VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; + stream->ThenStopTimer(timer.get()); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + VLOG(1) << "done with block-host-until-done"; + + // Merge in run-time profile information from execution_profile. + // + // TODO(b/71713097): This is buggy -- even though the mutex takes care of + // C++ level races, some other concurrent ExecuteOnStreamWrapper call could + // have rewritten the execution_profile before we get to it. + profile->MergeFrom(execution_profile()); + + // Overall execution time (in nanoseconds) from the executor timer. + if (stream->ok()) { + // Don't read timer->Nanoseconds() if the stream isn't OK -- that's + // illegal. + profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); + } + + // TODO(b/28123297): On GPU we end up including transfer time in + // the compute time this way. Instead, we should get the correct + // value by measuring it. Setting the field here at least lets + // benchmarks provide *some* value for GPU computations. + // + // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually + // the compute time without the transfer time, so this way we get the + // correct compute time. We should instead have the correct value for + // compute_and_transfer_time and set compute_time to the compute time. + if (profile->compute_time_ns() == 0) { + profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); + } + } + + if (profile_ptr != nullptr) { + XLA_LOG_LINES( + tensorflow::INFO, + profile_ptr->ToString(stream->parent()->GetDeviceDescription())); + hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", + profile_ptr.get()); + } + + return return_value; } Status Executable::DumpSessionModule() { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 08862308c90af736c1adcaa9438973f858852506..5ecfdffe211c571b1bb2bc30ff2acd3021c735ae 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -61,16 +61,7 @@ class Executable { // If the hlo_execution_profile is provided as non-nullptr, profiling will be // enabled. // - // Returns the device memory region that a successful execution would - // populate. - virtual StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) = 0; - - // Overload of ExecuteOnStream which returns and takes arguments as - // ShapedBuffers. Used for LocalService execution. + // Returns a shaped buffer containing the result of the computation. virtual StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -78,21 +69,19 @@ class Executable { // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. - virtual StatusOr ExecuteAsyncOnStream( + virtual StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) = 0; + tensorflow::gtl::ArraySlice arguments) = 0; // Same as ExecuteOnStream(), but runs this executable on multiple // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. - virtual StatusOr> - ExecuteOnStreams( + virtual StatusOr>> ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> + tensorflow::gtl::ArraySlice> arguments); // Populates `hlo_execution_profile` from `executor`. This is implicit in any @@ -107,13 +96,10 @@ class Executable { // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a // timer for the execution, sets up HLO profiling if enabled, and fills in the - // given ExecutionProfile if non-null. The ExecuteOnStream overloads have - // different argument types and return types, so this method is templated on - // argument type and return type of the execute function. - template - StatusOr ExecuteOnStreamWrapper( + // given ExecutionProfile if non-null. + StatusOr> ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - const ArgT& arguments); + tensorflow::gtl::ArraySlice arguments); // Returns the ExecutionProfile from executing on the device. This includes // the number of cycles taken for the computation or the compilation time. @@ -197,66 +183,6 @@ class Executable { std::unique_ptr hlo_profile_index_map_; }; -template -StatusOr Executable::ExecuteOnStreamWrapper( - const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - const ArgT& arguments) { - perftools::gputools::Stream* stream = run_options->stream(); - std::unique_ptr timer; - if (profile != nullptr) { - timer.reset(new perftools::gputools::Timer(stream->parent())); - stream->InitTimer(timer.get()).ThenStartTimer(timer.get()); - } - - VLOG(1) << "enqueueing executable on stream..."; - // If the profiling flag isn't enabled, we pass nullptr as the profile to - // indicate profiling is not requested. - std::unique_ptr profile_ptr = - module_config().debug_options().xla_hlo_profile() && - hlo_profiling_enabled() - ? MakeUnique(&hlo_profile_printer(), - &hlo_profile_index_map()) - : nullptr; - - auto return_value = - ExecuteOnStream(run_options, arguments, profile_ptr.get()); - - if (profile != nullptr) { - VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; - stream->ThenStopTimer(timer.get()).BlockHostUntilDone(); - VLOG(1) << "done with block-host-until-done"; - - // Merge in run-time profile information from execution_profile. - profile->MergeFrom(execution_profile()); - - // Overall execution time (in nanoseconds) from the executor timer. - profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); - - // TODO(b/28123297): On GPU we end up including transfer time in - // the compute time this way. Instead, we should get the correct - // value by measuring it. Setting the field here at least lets - // benchmarks provide *some* value for GPU computations. - // - // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually - // the compute time without the transfer time, so this way we get the - // correct compute time. We should instead have the correct value for - // compute_and_transfer_time and set compute_time to the compute time. - if (profile->compute_time_ns() == 0) { - profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); - } - } - - if (profile_ptr != nullptr) { - XLA_LOG_LINES( - tensorflow::INFO, - profile_ptr->ToString(stream->parent()->GetDeviceDescription())); - hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", - profile_ptr.get()); - } - - return return_value; -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index c225e62e3e11d2d01251b0f92272b0949eff8af1..2f0b9ed2bd98fbea4e67c0a30d5aa41ff6a06979 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -39,9 +39,7 @@ AsyncExecution::AsyncExecution(Backend* backend, tensorflow::Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { - if (!stream->BlockHostUntilDone()) { - return InternalError("failed to block until done"); - } + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index dfba22a6c4c5cf071c2cd8621643b8da6587ee3b..2b6caa149439a86d6d047605099bc3ff7b295a8e 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -26,7 +26,10 @@ namespace xla { namespace { -// Helper to replace the called computation at a while- or call-instruction. +// Helper to replace the called computation at a while-, call-, or +// conditional-instruction. This function replaces exactly one instance of +// 'computation' with 'new_computation' even if 'instruction' calls +// 'computation' more than once. void ReplaceCalledComputation(HloInstruction* instruction, HloComputation* computation, HloComputation* new_computation) { @@ -45,6 +48,15 @@ void ReplaceCalledComputation(HloInstruction* instruction, instruction->set_to_apply(new_computation); break; } + case HloOpcode::kConditional: { + if (computation == instruction->true_computation()) { + instruction->set_true_computation(new_computation); + } else { + CHECK_EQ(computation, instruction->false_computation()); + instruction->set_false_computation(new_computation); + } + break; + } default: LOG(FATAL) << "unexpected opcode: " << HloOpcodeString(instruction->opcode()); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index a68e90b7d009890012f94baa790d911871c9c960..d3854b40de3572a60df1ad99d8a4589f59ad7194 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -223,5 +223,35 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { EXPECT_EQ(1, b_node.caller_callsites().size()); } +TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { + auto module = CreateNewModule(); + HloComputation* sub_computation = + module->AddEmbeddedComputation(MakeScalarComputation()); + + // Create entry computation, which is a conditional that has the same + // computation in the true and false branch. + HloComputation::Builder builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, constant1, sub_computation, constant2, + sub_computation)); + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(2, module->computation_count()); + + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + EXPECT_TRUE(result); + std::unique_ptr call_graph = CallGraph::Build(module.get()); + // The true and false computations must now be different. + EXPECT_EQ(3, module->computation_count()); + + const CallGraphNode& sub_node = call_graph->GetNode(sub_computation); + EXPECT_EQ(1, sub_node.caller_callsites().size()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 74aa77b4f165be76fbc0a8aa1a4a7e90a8e9acec..78dc0ad4fcd167c93f19d0c2b18ea72d666897ef 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -51,83 +51,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { return platform_id_; } -Status GenericTransferManager::TransferLiteralFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& device_shape, const Shape& literal_shape, Literal* literal) { - VLOG(2) << "transferring literal shape from device: " - << ShapeUtil::HumanString(literal_shape) - << "; device location: " << source.opaque(); - TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape)); - - // Tuples are a special case and contain one or more shapes inside of them to - // an arbitrary nesting depth. - if (device_shape.element_type() == TUPLE) { - *literal->mutable_shape() = literal_shape; - TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - ShallowCopyTupleFromDevice(executor, source, device_shape)); - TF_RET_CHECK(element_buffers.size() == - ShapeUtil::TupleElementCount(device_shape)); - for (int64 i = 0; i < element_buffers.size(); ++i) { - const Shape& element_device_shape = device_shape.tuple_shapes(i); - const Shape& element_literal_shape = literal_shape.tuple_shapes(i); - Literal* element_literal = literal->add_tuple_literals(); - // Recursively call TransferFromDevice to copy over the data in the - // element array. - TF_RETURN_IF_ERROR(TransferLiteralFromDevice( - executor, element_buffers[i], /*device_shape=*/element_device_shape, - /*literal_shape=*/element_literal_shape, element_literal)); - } - return Status::OK(); - } - - *literal->mutable_shape() = device_shape; - literal->Reserve(ShapeUtil::ElementsIn(device_shape)); - TF_RETURN_IF_ERROR(TransferBufferFromDevice( - executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape), - /*destination=*/literal->MutableInternalData())); - if (!ShapeUtil::Equal(literal_shape, device_shape)) { - *literal = std::move(*literal->Relayout(literal_shape.layout())); - } - TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); - return Status::OK(); -} - -StatusOr> -GenericTransferManager::ShallowCopyTupleFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsTuple(shape)); - - // For devices which use the GenericTransferManager, a tuple is stored as an - // array of pointers to buffers. Copy the contents of the tuple buffer into - // a vector of void* pointers. - std::vector element_pointers(ShapeUtil::TupleElementCount(shape), - nullptr); - int64 tuple_size = ShapeUtil::ByteSizeOf(shape, pointer_size_); - auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size, - element_pointers.data()); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape)); - } - - // Create a DeviceMemoryBase from each void* pointer. - std::vector destination; - for (size_t i = 0; i < element_pointers.size(); ++i) { - if (element_pointers[i] == nullptr && - !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { - return FailedPrecondition("tuple contains nullptr at element %lu", i); - } - destination.emplace_back(element_pointers[i], - GetByteSizeRequirement(shape.tuple_shapes(i))); - } - return std::move(destination); -} - -Status GenericTransferManager::WriteTuplePointersToDevice( +Status GenericTransferManager::WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, const Shape& shape, perftools::gputools::DeviceMemoryBase* region) { @@ -145,16 +69,19 @@ StatusOr> GenericTransferManager::TransferLiteralFromDevice( se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { VLOG(2) << "transferring literal from device ordinal " - << executor->device_ordinal() << "; device shape: " - << ShapeUtil::HumanStringWithLayout(device_buffer.shape()) - << "; opaque: " << device_buffer.buffer(/*index=*/{}).opaque(); + << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + // The on-host and on-device shape should always be the same for the generic + // transfer manager. + TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), + device_buffer.on_host_shape())); + std::unique_ptr literal = - Literal::CreateFromShape(device_buffer.shape()); + Literal::CreateFromShape(device_buffer.on_host_shape()); TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { if (!ShapeUtil::IsTuple(subshape)) { TF_RETURN_IF_ERROR(TransferBufferFromDevice( @@ -162,7 +89,7 @@ GenericTransferManager::TransferLiteralFromDevice( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ - literal->GetSubliteral(index).MutableInternalData())); + literal->untyped_data(index))); } return Status::OK(); @@ -175,33 +102,39 @@ Status GenericTransferManager::TransferLiteralToDevice( const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " - << ShapeUtil::HumanString(shape) << "; device location: " - << device_buffer.buffer(/*index=*/{}).opaque(); + << ShapeUtil::HumanString(shape) + << "; device buffer: " << device_buffer; + + // The on-host and on-device shape should always be the same for the generic + // transfer manager. + TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), + device_buffer.on_host_shape())); - TF_RET_CHECK(ShapeUtil::Compatible(literal.shape(), device_buffer.shape())); + TF_RET_CHECK( + ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_host_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); if (ShapeUtil::IsArray(device_subshape)) { TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. - const Literal& subliteral = literal.GetSubliteral(index); + const auto subliteral = LiteralView::Create(literal, index); std::unique_ptr relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { - source = subliteral.InternalData(); + source = subliteral.untyped_data(); } else { // Relayout data before transferring. relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); - source = relayed_out_literal->InternalData(); + source = relayed_out_literal->untyped_data(); } return TransferBufferToDevice( executor, @@ -212,33 +145,6 @@ Status GenericTransferManager::TransferLiteralToDevice( }); } -Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, - se::DeviceMemoryBase* destination) { - const Shape& shape = literal.shape(); - VLOG(2) << "transferring literal shape to device: " - << ShapeUtil::HumanString(shape) - << "; device location: " << destination->opaque(); - - if (ShapeUtil::IsTuple(literal.shape())) { - std::vector tuple_elements_on_device; - for (const Literal& tuple_element : literal.tuple_literals()) { - se::DeviceMemoryBase allocation = executor->AllocateArray( - GetByteSizeRequirement(tuple_element.shape())); - TF_RETURN_IF_ERROR( - TransferLiteralToDevice(executor, tuple_element, &allocation)); - tuple_elements_on_device.push_back(allocation.opaque()); - } - return TransferBufferToDevice( - executor, tuple_elements_on_device.size() * sizeof(void*), - tuple_elements_on_device.data(), destination); - } - - return TransferBufferToDevice(executor, - /*size=*/GetByteSizeRequirement(shape), - /*source=*/literal.InternalData(), destination); -} - Status GenericTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const Literal& literal) { return Unimplemented("Generic transfer to Infeed"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 50dca6aec5012f0b02cb54846b622f008600e48e..63a7c820cf4e5fbbdf870086a4fb5316ac50d10b 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -42,16 +42,6 @@ class GenericTransferManager : public TransferManager { perftools::gputools::Platform::Id PlatformId() const override; - Status TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& device_shape, const Shape& literal_shape, - Literal* literal) override; - - Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - perftools::gputools::DeviceMemoryBase* destination) override; - StatusOr> TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; @@ -62,9 +52,6 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; - Status TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; @@ -73,16 +60,13 @@ class GenericTransferManager : public TransferManager { tensorflow::gtl::ArraySlice executors) override; - StatusOr> - ShallowCopyTupleFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& shape) override; - int64 GetByteSizeRequirement(const Shape& shape) const override; protected: - Status WriteTuplePointersToDevice( + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; + + Status WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index e57558b5788965214cadf5eab1024860f1a39ca1..d7ca0f6846834ae77569930325d3fc6b9fd5cca8 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -23,6 +23,15 @@ filegroup( load("//tensorflow:tensorflow.bzl", "tf_cc_test") +cc_library( + name = "gpu_constants", + srcs = ["gpu_constants.cc"], + hdrs = ["gpu_constants.h"], + deps = [ + "//tensorflow/compiler/xla:types", + ], +) + cc_library( name = "partition_assignment", srcs = [ @@ -123,6 +132,7 @@ cc_library( ], deps = [ ":elemental_ir_emitter", + ":gpu_constants", ":gpu_executable", ":hlo_to_ir_bindings", ":ir_emission_utils", @@ -203,6 +213,7 @@ cc_library( srcs = ["buffer_allocations.cc"], hdrs = ["buffer_allocations.h"], deps = [ + ":gpu_constants", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -219,6 +230,8 @@ cc_library( srcs = [ "convolution_thunk.cc", "copy_thunk.cc", + "cudnn_batchnorm_thunk.cc", + "fft_thunk.cc", "for_thunk.cc", "gemm_thunk.cc", "gpu_executable.cc", @@ -232,6 +245,8 @@ cc_library( hdrs = [ "convolution_thunk.h", "copy_thunk.h", + "cudnn_batchnorm_thunk.h", + "fft_thunk.h", "for_thunk.h", "gemm_thunk.h", "gpu_executable.h", @@ -246,6 +261,7 @@ cc_library( deps = [ ":buffer_allocations", ":infeed_manager", + ":ir_emission_utils", ":partition_assignment", ":stream_assignment", "//tensorflow/compiler/xla:array2d", @@ -269,6 +285,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", + "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep ], ) @@ -429,13 +446,15 @@ cc_library( deps = [ ":convolution_folding", ":fusion_merger", + ":gpu_constants", ":gpu_copy_insertion", ":gpu_executable", + ":gpu_hlo_support_checker", + ":gpu_layout_assignment", ":hlo_schedule", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -445,10 +464,11 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:algebraic_simplifier", - "//tensorflow/compiler/xla/service:batchnorm_rewriter", + "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -467,11 +487,14 @@ cc_library( "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_simplifier", + "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", + "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", "@llvm//:support", @@ -479,6 +502,19 @@ cc_library( alwayslink = True, # Contains compiler registration ) +cc_library( + name = "cudnn_batchnorm_rewriter", + srcs = ["cudnn_batchnorm_rewriter.cc"], + hdrs = ["cudnn_batchnorm_rewriter.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "@llvm//:core", + ], +) + cc_library( name = "infeed_manager", srcs = ["infeed_manager.cc"], @@ -492,9 +528,9 @@ cc_library( ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "gpu_layout_assignment", + srcs = ["gpu_layout_assignment.cc"], + hdrs = ["gpu_layout_assignment.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", @@ -508,17 +544,18 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", - srcs = ["layout_assignment_test.cc"], + name = "gpu_layout_assignment_test", + srcs = ["gpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":gpu_layout_assignment", + ":ir_emission_utils", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], ) @@ -586,6 +623,32 @@ tf_cc_test( ], ) +cc_library( + name = "gpu_hlo_support_checker", + srcs = ["gpu_hlo_support_checker.cc"], + hdrs = ["gpu_hlo_support_checker.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "gpu_hlo_support_checker_test", + srcs = ["gpu_hlo_support_checker_test.cc"], + deps = [ + ":gpu_hlo_support_checker", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 9fdf717b5d463010e2709b6209c070f25555de72..ed78fef4113bd9f7048ca3c8c2d4e38c5ec4762a 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -48,6 +49,15 @@ StatusOr> BufferAllocations::Builder::Build( // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. if (registered_buffers_.count(i)) { + se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); + if (reinterpret_cast(address.opaque()) % + kCudaMallocAlignBytes != + 0) { + return InternalError( + "Address of registered buffer %lld must be a multiple of %llx, but " + "was %p", + i, kCudaMallocAlignBytes, address.opaque()); + } buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i)); continue; } @@ -67,6 +77,14 @@ StatusOr> BufferAllocations::Builder::Build( tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(), i); } + if (reinterpret_cast(buffer_address.opaque()) % + kCudaMallocAlignBytes != + 0) { + return InternalError( + "Address returned by memory_allocator->Allocate must be a " + "multiple of %llx, but was %p", + kCudaMallocAlignBytes, buffer_address.opaque()); + } } buffer_allocations->SetBuffer(i, buffer_address); if (allocation.IsPreallocatedTempBuffer()) { diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc index 828ae675d7ba60b4cee1c3f5312b069263d5a814..b0626ca3bc9f843e513d4727932f0e2d5fa37748 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -55,19 +55,7 @@ MatchBackwardFilter(HloInstruction* conv) { // v v // Convolution // conv - // | - // v - // Transpose (optional if identity transposition) CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); - // If the forward convolution is followed by a transpose, we can fuse the - // transpose into the backward convolution as well. - HloInstruction* transpose = nullptr; - if (conv->user_count() == 1) { - HloInstruction* single_user = *conv->users().begin(); - if (single_user->opcode() == HloOpcode::kTranspose) { - transpose = single_user; - } - } // Step 2: match paddings and dimension numbers of the forward convolution. const ConvolutionDimensionNumbers& conv_dnums = @@ -75,6 +63,9 @@ MatchBackwardFilter(HloInstruction* conv) { auto input_batch_dim = conv_dnums.input_batch_dimension(); auto input_feature_dim = conv_dnums.input_feature_dimension(); auto input_spatial_dims = conv_dnums.input_spatial_dimensions(); + auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension(); + auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension(); + auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions(); auto output_batch_dim = conv_dnums.output_batch_dimension(); auto output_feature_dim = conv_dnums.output_feature_dimension(); auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); @@ -96,9 +87,14 @@ MatchBackwardFilter(HloInstruction* conv) { VLOG(1) << "Padding low should be non-negative."; return no_match_result; } + if (window_dim.window_reversal()) { + VLOG(1) << "Window reversal field not supported"; + return no_match_result; + } // Padding high will be checked in Step 3. } - if (transpose == nullptr && !window_util::HasWindowDilation(conv->window())) { + if (input_batch_dim == output_batch_dim && + !window_util::HasWindowDilation(conv->window())) { VLOG(1) << conv->ToString() << " is a regular forward convolution. No need " "to fold it to a backward filter convolution."; @@ -169,53 +165,32 @@ MatchBackwardFilter(HloInstruction* conv) { } } - // To make future HLO passes easier, we canonicalize the fused expression by - // adding an identity transposition if it's omitted in the pattern. - if (transpose == nullptr) { - // Create an identity transposition with the same rank as the forward - // convolution. - HloComputation* parent_computation = conv->parent(); - std::vector transpose_dimensions(ShapeUtil::Rank(conv->shape())); - std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0); - transpose = - parent_computation->AddInstruction(HloInstruction::CreateTranspose( - conv->shape(), conv, transpose_dimensions)); - TF_CHECK_OK(conv->ReplaceAllUsesWith(transpose)); - } - // Restore the dimension numbers of the backward convolution from the forward // convolution. The two activation dimensions are reversed (batch and // feature). ConvolutionDimensionNumbers backward_conv_dnums; backward_conv_dnums.set_input_batch_dimension(input_feature_dim); backward_conv_dnums.set_input_feature_dimension(input_batch_dim); - backward_conv_dnums.set_output_batch_dimension(output_feature_dim); - backward_conv_dnums.set_output_feature_dimension(output_batch_dim); for (int i = 0; i < input_spatial_dims.size(); ++i) { backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]); } - for (int i = 0; i < output_spatial_dims.size(); ++i) { - backward_conv_dnums.add_output_spatial_dimensions(output_spatial_dims[i]); + backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim); + backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim); + for (int i = 0; i < kernel_spatial_dims.size(); ++i) { + backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]); } // The dimension numbering of the output of the forward convolution (before // transposition) is the same as that of the activations (according to the // semantics of kConvolution). The batch dimension of the activations should // be treated as the input feature dimension, and the feature dimension should // be treated as the output feature. - // - // The output of the forward convolution needs to be transposed to fit into - // the dimension numbering of the weight gradients. This transposition maps - // dimension i to PositionInContainer(transpose->dimensions(), i). - backward_conv_dnums.set_kernel_input_feature_dimension( - PositionInContainer(transpose->dimensions(), output_batch_dim)); - backward_conv_dnums.set_kernel_output_feature_dimension( - PositionInContainer(transpose->dimensions(), output_feature_dim)); + backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim); + backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim); for (int i = 0; i < output_spatial_dims.size(); ++i) { - backward_conv_dnums.add_kernel_spatial_dimensions( - PositionInContainer(transpose->dimensions(), output_spatial_dims[i])); + backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, std::vector({transpose, conv}), + return std::make_tuple(true, std::vector({conv}), backward_conv_window, backward_conv_dnums); } @@ -275,6 +250,10 @@ MatchBackwardInput(HloInstruction* conv) { << " should have no window dilation."; return no_match_result; } + if (window_dim.window_reversal()) { + VLOG(1) << "Window reversal field not supported"; + return no_match_result; + } } const auto& input_spatial_dims = dnums.input_spatial_dimensions(); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc index 112c496e1f6bd17f89ac389ccf0256846dfa1971..34e6bdb117d47a3d7e1eb3bae5806e130e94ea79 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc @@ -46,18 +46,18 @@ class ConvolutionFoldingTest : public HloTestBase { // // TODO(jingyue): Add more tests on NCHW input order which TF also supports. tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3); tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); - tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0); tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); - tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2); - tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(2); tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( 3); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3); tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); @@ -86,7 +86,7 @@ class ConvolutionFoldingTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) { +TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -136,7 +136,7 @@ TEST_F(ConvolutionFoldingTest, auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + EXPECT_TRUE(FoldConvolution(module.get())); } // Extracted from block35 training. @@ -155,13 +155,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { conv_window.mutable_dimensions(i)->set_padding_low(1); conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -189,13 +185,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { conv_window.mutable_dimensions(i)->set_padding_high(-1); conv_window.mutable_dimensions(i)->set_window_dilation(2); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -222,13 +214,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { // Uneven padding: padding_low=0, padding_high=1 conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 037eec8ef59e1aeccdfc43dbb5c1a852403780d1..899cc5c83b99f1bb6154f883ca17871863e1f457 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -314,7 +314,9 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( const ConvolutionDescriptor& convolution_descriptor, const BufferAllocations& buffer_allocations, se::Stream* stream) { // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - if (best_algorithm_.algorithm().is_default()) { + if (!best_algorithm_.has_value()) { + best_algorithm_.emplace(); + // Auto-tuning either is disabled or only happens in the first run of this // function. VLOG(2) << "Profiling for best convolution algorithm used for " @@ -363,35 +365,35 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( } if (best_result.is_valid()) { - best_algorithm_.set_algorithm(best_result.algorithm()); + best_algorithm_->set_algorithm(best_result.algorithm()); } else { LOG(ERROR) << "No convolution algorithm works with profiling. Fall back " "to the default algorithm."; - best_algorithm_.set_algorithm(AlgorithmDesc()); + best_algorithm_->set_algorithm(AlgorithmDesc()); } if (best_result_without_scratch.is_valid()) { - best_algorithm_.set_algorithm_no_scratch( + best_algorithm_->set_algorithm_no_scratch( best_result_without_scratch.algorithm()); } else { LOG(ERROR) << "No convolution algorithm without scratch works with " "profiling. Fall back " "to the default algorithm."; - best_algorithm_.set_algorithm_no_scratch(AlgorithmDesc()); + best_algorithm_->set_algorithm_no_scratch(AlgorithmDesc()); } } { VLOG(2) << "Using convolution algorithm (" - << AlgorithmToString(best_algorithm_.algorithm()) << ", " - << AlgorithmToString(best_algorithm_.algorithm_no_scratch()) + << AlgorithmToString(best_algorithm_->algorithm()) << ", " + << AlgorithmToString(best_algorithm_->algorithm_no_scratch()) << ") for ConvolutionThunk: " << this; ConvolveScratchAllocator scratch_allocator( buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); return Convolve(input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, - convolution_descriptor, best_algorithm_, stream, + convolution_descriptor, *best_algorithm_, stream, &scratch_allocator, nullptr); } } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 5ac5db2f04b6796c6013a7f87dd40b485233baa6..46c94d0bf1e486fb91e63109efb8e4ba778c4120 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -87,6 +88,34 @@ class ConvolutionThunk : public Thunk { const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; + // Returns true if the next run of ExecuteOnStream will do autotuning. If so, + // we want the GPU to be quiescent during autotuning, so as not to introduce + // noise in our results. + bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream*) override { + return !best_algorithm_.has_value(); + } + + // Return true if scratch memory is needed to execute the thunk, that is + // either the best algorithm hasn't been chosen or the best algorithm is not + // the same as the no-scratch algorithm. This is because that the execution + // of the thunk is asynchronous, and the scratch allocator goes out of + // scope before the thunk finishes execution. Returning true tells the stream + // executor to make future thunks wait for this thunk to avoid reusing the + // deallocated scratch memory until this thunk is done with it. + bool ShouldBlockFutureThunks() { + if (!best_algorithm_.has_value()) { + return true; + } + + const perftools::gputools::dnn::AlgorithmDesc& best_alg = + best_algorithm_->algorithm(); + const perftools::gputools::dnn::AlgorithmDesc& no_scratch_best_alg = + best_algorithm_->algorithm_no_scratch(); + return (!best_alg.is_default() || !no_scratch_best_alg.is_default() || + !(best_alg == no_scratch_best_alg)); + } + private: tensorflow::Status ConvolveWithTune( const perftools::gputools::dnn::BatchDescriptor& input_descriptor, @@ -121,9 +150,10 @@ class ConvolutionThunk : public Thunk { // Fastest cuDNN convolution algorithm for this thunk learned from // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set - // to the default value indicating cuDNN's convolution will choose - // the best algorithm from some heuristics based on its parameters. - perftools::gputools::dnn::AlgorithmConfig best_algorithm_; + // to the default value, indicating cuDNN's convolution will choose the best + // algorithm from some heuristics based on its parameters. + tensorflow::gtl::optional + best_algorithm_; const ConvolutionKind convolution_kind_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..db6924c742e4a949a3e939b6d6659e92c2d1e312 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -0,0 +1,219 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" + +namespace xla { +namespace gpu { +namespace { + +class Visitor : public DfsHloVisitorWithDefault { + public: + explicit Visitor(HloComputation* computation) : computation_(computation) {} + + static bool Run(HloComputation* computation) { + Visitor visitor(computation); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleBatchNormInference(HloInstruction* batch_norm) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm) override; + + private: + bool changed_ = false; + HloComputation* computation_; +}; + +// cudnn defines CUDNN_BN_MIN_EPSILON = 1e-5 as the minimum acceptable epsilon +// for calls to its batchnorm ops. +bool EpsilonInRange(HloInstruction* batch_norm) { + return batch_norm->epsilon() >= 1e-5; +} + +Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) { + if (batch_norm->operand(0)->shape().element_type() != F32) { + VLOG(1) << "Not rewriting op with non-F32 element type: " + << batch_norm->ToString(); + return Status::OK(); + } + + // cudnn errors out on zero-sized inputs. + if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) { + return Status::OK(); + } + + if (!EpsilonInRange(batch_norm)) { + return Status::OK(); + } + + HloInstruction* epsilon = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* feature_index = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(batch_norm->feature_index()))); + + std::vector operands(batch_norm->operands().begin(), + batch_norm->operands().end()); + operands.push_back(epsilon); + operands.push_back(feature_index); + + std::unique_ptr libcall = HloInstruction::CreateCustomCall( + batch_norm->shape(), operands, kCudnnBatchNormForwardInferenceCallTarget); + TF_RETURN_IF_ERROR( + computation_->ReplaceWithNewInstruction(batch_norm, std::move(libcall))); + changed_ = true; + return Status::OK(); +} + +Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { + if (batch_norm->operand(0)->shape().element_type() != F32) { + VLOG(1) << "Not rewriting op with non-F32 element type: " + << batch_norm->ToString(); + return Status::OK(); + } + + // cudnn errors out on zero-sized inputs. + if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) { + return Status::OK(); + } + + if (!EpsilonInRange(batch_norm)) { + return Status::OK(); + } + + HloInstruction* epsilon = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* feature_index = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(batch_norm->feature_index()))); + + std::vector operands(batch_norm->operands().begin(), + batch_norm->operands().end()); + operands.push_back(epsilon); + operands.push_back(feature_index); + + HloInstruction* libcall = + computation_->AddInstruction(HloInstruction::CreateCustomCall( + batch_norm->shape(), operands, + kCudnnBatchNormForwardTrainingCallTarget)); + + // The cudnn libcall returns a tuple + // {output, mean, rsqrt(variance + epsilon)}, + // but the batchnorm HLO returns {output, mean, variance}. Fix it up. + HloInstruction* inverse_stddev = + computation_->AddInstruction(HloInstruction::CreateGetTupleElement( + libcall->shape().tuple_shapes(2), libcall, 2)); + HloInstruction* variance_plus_epsilon = + computation_->AddInstruction(HloInstruction::CreateBinary( + inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev, + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-2))))); + HloInstruction* variance = + computation_->AddInstruction(HloInstruction::CreateBinary( + variance_plus_epsilon->shape(), HloOpcode::kSubtract, + variance_plus_epsilon, epsilon)); + + // Repackage the results. + std::unique_ptr new_tuple = HloInstruction::CreateTuple({ + computation_->AddInstruction(HloInstruction::CreateGetTupleElement( + libcall->shape().tuple_shapes(0), libcall, 0)), + computation_->AddInstruction(HloInstruction::CreateGetTupleElement( + libcall->shape().tuple_shapes(1), libcall, 1)), + variance, + }); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + batch_norm, std::move(new_tuple))); + changed_ = true; + return Status::OK(); +} + +Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { + if (batch_norm->operand(0)->shape().element_type() != F32) { + VLOG(1) << "Not rewriting op with non-F32 element type: " + << batch_norm->ToString(); + return Status::OK(); + } + + // cudnn errors out on zero-sized inputs. + if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) { + return Status::OK(); + } + + if (!EpsilonInRange(batch_norm)) { + return Status::OK(); + } + + HloInstruction* epsilon = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* feature_index = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(batch_norm->feature_index()))); + + // The cudnn libcall expects its input to be rsqrt(variance + epsilon), but + // the batchnorm HLO takes plain variance as input. Fix it up. + HloInstruction* var_plus_epsilon = + computation_->AddInstruction(HloInstruction::CreateBinary( + batch_norm->operand(3)->shape(), HloOpcode::kAdd, + batch_norm->mutable_operand(3), epsilon)); + HloInstruction* inverse_stddev = + computation_->AddInstruction(HloInstruction::CreateBinary( + var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon, + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-.5))))); + + std::vector operands(batch_norm->operands().begin(), + batch_norm->operands().end()); + operands[3] = inverse_stddev; + operands.push_back(epsilon); + operands.push_back(feature_index); + + std::unique_ptr libcall = HloInstruction::CreateCustomCall( + batch_norm->shape(), operands, kCudnnBatchNormBackwardCallTarget); + + TF_RETURN_IF_ERROR( + computation_->ReplaceWithNewInstruction(batch_norm, std::move(libcall))); + changed_ = true; + return Status::OK(); +} + +} // anonymous namespace + +StatusOr CudnnBatchNormRewriter::Run(HloModule* module) { + VLOG(2) << "CudnnBatchNormRewriter::Run(), before:"; + XLA_VLOG_LINES(2, module->ToString()); + + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (Visitor::Run(comp)) { + changed = true; + } + } + + VLOG(2) << "CudnnBatchNormRewriter::Run(), after:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..e09cde9abf85454c7a020566cd8c2671ae12ffc3 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -0,0 +1,66 @@ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Rewrites BatchNorm HLOs into calls into cudnn where possible. +// +// A call into cudnn for performing a batchnorm op is represented as a +// CustomCall HLO with custom_call_target equal to one of +// +// - kCudnnBatchNormForwardInferenceCallTarget +// - kCudnnBatchNormForwardTrainingCallTarget, or +// - kCudnnBatchNormBackwardCallTarget. +// +// A CustomCall created by this pass has the same operands corresponding +// batchnorm HLO, except the epsilon() and feature_index() properties of the +// batchnorm HLO are converted into proper operands, added to the end of the +// CustomCall's operands list. +// +// The inputs/outputs of the cudnn calls for BatchNormTraining and BatchNormGrad +// do not correspond exactly to the HLOs. In particular, the training cudnn +// call returns 1/sqrt(variance + epsilon), while the HLO returns plain +// variance. Similarly, the grad cudnn call expects 1/sqrt(variance + epsilon) +// as input, whereas the HLO expects plain variance. +// +// This pass adds HLOs in front of / behind the CustomCalls to fix up the +// inputs/outputs as appropriate, and we rely on the AlgebraicSimplifier to +// remove these where possible. +// +// Currently batchnorm ops over F32s are converted into cudnn calls, so long as +// epsilon is not too small. This pass leaves other batchnorm ops unmodified. +// +// The GPU backend does not implement a lowering for the batchnorm HLOs -- it +// expects them to be lowered to cudnn calls via this pass or to HLO soup via +// BatchNormRewriter. +class CudnnBatchNormRewriter : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "cudnn_batchnorm_rewriter"; + } + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..58d9c8caff31e878487fbef01afce566e6187fd9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -0,0 +1,285 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" + +#include + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +namespace se = ::perftools::gputools; +namespace dnn = se::dnn; + +static std::pair +MakeDescriptors(const Shape& shape, int64 feature_index) { + std::vector logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(shape.layout()); + + auto physical_dim_size = [&](int64 physical_dim) { + return shape.dimensions(LayoutUtil::Major(shape.layout(), physical_dim)); + }; + + // Batchnorm only cares about the location of the depth (aka "feature") dim. + // The other dims are all treated the same. Thus we can use the kBatchDepthYX + // cudnn layout for any XLA shape+layout, even XLA shapes that don't have + // exactly 4 dimensions: We put everything that comes before the feature dim + // into "batch", and everything that comes after the feature dim into "Y". + int64 batch_size = 1; + int64 y_size = 1; + int64 physical_dim; + for (physical_dim = 0; physical_dim != logical_to_physical[feature_index]; + ++physical_dim) { + CHECK_LT(physical_dim, shape.dimensions_size()); + batch_size *= physical_dim_size(physical_dim); + } + ++physical_dim; // Skip the feature dimension. + for (; physical_dim < shape.dimensions_size(); ++physical_dim) { + y_size *= physical_dim_size(physical_dim); + } + + dnn::BatchDescriptor input_desc; + input_desc.set_layout(dnn::DataLayout::kBatchDepthYX) + .set_count(batch_size) + .set_feature_map_count(shape.dimensions(feature_index)) + .set_height(y_size) + .set_width(1); + + dnn::BatchDescriptor scale_offset_desc; + scale_offset_desc.set_layout(dnn::DataLayout::kBatchDepthYX) + .set_feature_map_count(input_desc.feature_map_count()) + .set_height(1) + .set_width(1) + .set_count(1); + + return std::make_pair(input_desc, scale_offset_desc); +} + +CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, + const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& variance, float epsilon, int64 feature_index, + const BufferAllocation::Slice& output, const HloInstruction* hlo) + : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo), + operand_(operand), + scale_(scale), + offset_(offset), + mean_(mean), + variance_(variance), + epsilon_(epsilon), + feature_index_(feature_index), + output_(output) { + CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); + CHECK_EQ(hlo->custom_call_target(), + kCudnnBatchNormForwardInferenceCallTarget); + CHECK( + LayoutUtil::LayoutsInShapesEqual(hlo->shape(), hlo->operand(0)->shape())); + CHECK_EQ(hlo->shape().element_type(), F32) << "Not yet implemented"; +} + +Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + dnn::BatchDescriptor operand_desc; + dnn::BatchDescriptor scale_offset_desc; + std::tie(operand_desc, scale_offset_desc) = + MakeDescriptors(hlo_instruction()->shape(), feature_index_); + + se::DeviceMemory output(buffer_allocations.GetDeviceAddress(output_)); + stream->ThenBatchNormalizationForward( + se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(mean_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(variance_)), + operand_desc, // + scale_offset_desc, // + epsilon_, // + &output, // + /*batch_mean=*/nullptr, // + /*batch_var=*/nullptr, // + /*saved_mean=*/nullptr, // + /*saved_inv_var=*/nullptr, // + /*is_training=*/false, // + /*var_to_inv_var=*/nullptr, // + /*inv_var_to_var=*/nullptr); + if (!stream->ok()) { + return InternalError("BatchNormalizationForward call failed."); + } + return Status::OK(); +} + +CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, + float epsilon, int64 feature_index, + const BufferAllocation::Slice& output_data, + const BufferAllocation::Slice& output_mean, + const BufferAllocation::Slice& output_inv_stddev, + const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo) + : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo), + operand_(operand), + scale_(scale), + offset_(offset), + epsilon_(epsilon), + feature_index_(feature_index), + output_data_(output_data), + output_mean_(output_mean), + output_inv_stddev_(output_inv_stddev), + output_tuple_(output_tuple) { + CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); + CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget); + CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); + CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), + hlo->operand(0)->shape())); + for (const auto& tuple_shape : hlo->shape().tuple_shapes()) { + CHECK_EQ(tuple_shape.element_type(), F32) << "Not yet implemented"; + } +} + +Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + dnn::BatchDescriptor operand_desc; + dnn::BatchDescriptor scale_offset_desc; + // The BatchNormTraining HLO outputs a tuple of three elements: output data, + // batch mean, and batch variance. We want to make our descriptors based on + // the shape of the output data. + std::tie(operand_desc, scale_offset_desc) = MakeDescriptors( + hlo_instruction()->shape().tuple_shapes(0), feature_index_); + + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_data_)); + se::DeviceMemory output_mean( + buffer_allocations.GetDeviceAddress(output_mean_)); + se::DeviceMemory output_inv_stddev( + buffer_allocations.GetDeviceAddress(output_inv_stddev_)); + + se::DeviceMemory null_device_ptr(nullptr); + stream->ThenBatchNormalizationForward( + se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), + /*estimated_mean=*/null_device_ptr, + /*estimated_variance=*/null_device_ptr, + operand_desc, // + scale_offset_desc, // + epsilon_, // + &output_data, // + /*batch_mean=*/&null_device_ptr, // + /*batch_var=*/&null_device_ptr, // + /*saved_mean=*/&output_mean, // + /*saved_inv_var=*/&output_inv_stddev, // + /*is_training=*/true, // + /*var_to_inv_var=*/nullptr, // + /*inv_var_to_var=*/nullptr); + + // Write the tuple. + void* ptrs[] = {output_data.opaque(), output_mean.opaque(), + output_inv_stddev.opaque()}; + se::DeviceMemory tuple_addr( + buffer_allocations.GetDeviceAddress(output_tuple_)); + stream->ThenMemcpyH2D(ptrs, &tuple_addr); + + if (!stream->ok()) { + return InternalError("BatchNormalizationTraining call failed."); + } + return Status::OK(); +} + +CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& inv_stddev, + const BufferAllocation::Slice& grad_output, float epsilon, + int64 feature_index, const BufferAllocation::Slice& output_grad_data, + const BufferAllocation::Slice& output_grad_scale, + const BufferAllocation::Slice& output_grad_offset, + const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo) + : Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo), + operand_(operand), + scale_(scale), + mean_(mean), + inv_stddev_(inv_stddev), + grad_output_(grad_output), + epsilon_(epsilon), + feature_index_(feature_index), + output_grad_data_(output_grad_data), + output_grad_scale_(output_grad_scale), + output_grad_offset_(output_grad_offset), + output_tuple_(output_tuple) { + CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); + CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget); + CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); + CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), + hlo->operand(0)->shape())); + CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), + hlo->operand(4)->shape())); + for (const auto& tuple_shape : hlo->shape().tuple_shapes()) { + CHECK_EQ(tuple_shape.element_type(), F32) << "Not yet implemented"; + } +} + +Status CudnnBatchNormBackwardThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + dnn::BatchDescriptor operand_desc; + dnn::BatchDescriptor scale_offset_desc; + + // This call outputs a tuple of three elements: grad data, grad offset, and + // grad scale. We want to make our descriptors based on the shape of the grad + // data. + std::tie(operand_desc, scale_offset_desc) = MakeDescriptors( + hlo_instruction()->shape().tuple_shapes(0), feature_index_); + + se::DeviceMemory output_grad_data( + buffer_allocations.GetDeviceAddress(output_grad_data_)); + se::DeviceMemory output_grad_scale( + buffer_allocations.GetDeviceAddress(output_grad_scale_)); + se::DeviceMemory output_grad_offset( + buffer_allocations.GetDeviceAddress(output_grad_offset_)); + + stream->ThenBatchNormalizationBackward( + se::DeviceMemory( + buffer_allocations.GetDeviceAddress(grad_output_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(mean_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(inv_stddev_)), + operand_desc, scale_offset_desc, epsilon_, &output_grad_data, + &output_grad_scale, &output_grad_offset); + + // Write the output tuple. + void* ptrs[] = {output_grad_data.opaque(), output_grad_scale.opaque(), + output_grad_offset.opaque()}; + se::DeviceMemory tuple_addr( + buffer_allocations.GetDeviceAddress(output_tuple_)); + stream->ThenMemcpyH2D(ptrs, &tuple_addr); + + if (!stream->ok()) { + return InternalError("BatchNormalizationBackward call failed."); + } + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..c5fbb6d8a3912d380172d496d8d35e80dc9f5c71 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { +namespace gpu { + +// This file contains thunks which call into cudnn to run the various flavors of +// batch normalization: BatchNormInference, BatchNormTraining, and +// BatchNormGrad, known to cudnn as BatchNormForwardInference, +// BatchNormForwardTraining, and BatchNormBackward. +// +// As an alternative to using these thunks, XLA can decompose batchnorm HLOs +// into smaller components using the BatchNormRewriter pass. This can result in +// faster code because those individual components can fuse into their +// inputs/outputs, but it may also be slower if cudnn's batchnorm implementation +// outperforms the code XLA generates for these components. +// +// Currently these thunks require that their inputs are F32s. +// +// Note that these thunks do not take full advantage of the cudnn batchnorm +// functions. For example, cudnn lets you bias and/or scale the input/output, +// but these thunks don't currently support that. + +class CudnnBatchNormForwardInferenceThunk : public Thunk { + public: + CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, + const BufferAllocation::Slice& offset, + const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& variance, + float epsilon, int64 feature_index, + const BufferAllocation::Slice& output, + const HloInstruction* hlo); + + CudnnBatchNormForwardInferenceThunk( + const CudnnBatchNormForwardInferenceThunk&) = delete; + CudnnBatchNormForwardInferenceThunk& operator=( + const CudnnBatchNormForwardInferenceThunk&) = delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Slice operand_; + BufferAllocation::Slice scale_; + BufferAllocation::Slice offset_; + BufferAllocation::Slice mean_; + BufferAllocation::Slice variance_; + float epsilon_; + int64 feature_index_; + BufferAllocation::Slice output_; +}; + +class CudnnBatchNormForwardTrainingThunk : public Thunk { + public: + CudnnBatchNormForwardTrainingThunk( + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, + const BufferAllocation::Slice& offset, float epsilon, int64 feature_index, + const BufferAllocation::Slice& output_data, + const BufferAllocation::Slice& output_mean, + const BufferAllocation::Slice& output_inv_stddev, + const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo); + + CudnnBatchNormForwardTrainingThunk( + const CudnnBatchNormForwardTrainingThunk&) = delete; + CudnnBatchNormForwardTrainingThunk& operator=( + const CudnnBatchNormForwardTrainingThunk&) = delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Slice operand_; + BufferAllocation::Slice scale_; + BufferAllocation::Slice offset_; + float epsilon_; + int64 feature_index_; + BufferAllocation::Slice output_data_; + BufferAllocation::Slice output_mean_; + BufferAllocation::Slice output_inv_stddev_; + BufferAllocation::Slice output_tuple_; +}; + +class CudnnBatchNormBackwardThunk : public Thunk { + public: + CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, + const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& inv_stddev, + const BufferAllocation::Slice& grad_output, + float epsilon, int64 feature_index, + const BufferAllocation::Slice& output_grad_data, + const BufferAllocation::Slice& output_grad_scale, + const BufferAllocation::Slice& output_grad_offset, + const BufferAllocation::Slice& output_tuple, + const HloInstruction* hlo); + + CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete; + CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) = + delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Slice operand_; + BufferAllocation::Slice scale_; + BufferAllocation::Slice mean_; + BufferAllocation::Slice inv_stddev_; + BufferAllocation::Slice grad_output_; + float epsilon_; + int64 feature_index_; + BufferAllocation::Slice output_grad_data_; + BufferAllocation::Slice output_grad_scale_; + BufferAllocation::Slice output_grad_offset_; + BufferAllocation::Slice output_tuple_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 6bf00cfb8a53723ae9608093480bf2eed10144dd..4b511cb4bb94addfae53d6b2e6d6f86d5b9afd84 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -135,10 +135,6 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { - case HloOpcode::kAtan2: - return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value}, - {lhs_input_type, rhs_input_type}, - output_type); case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, {lhs_input_type, rhs_input_type}, @@ -199,29 +195,50 @@ StatusOr GpuElementalIrEmitter::EmitErfcInv( return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitLog( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitSin( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitCos( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitExp( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type}, + prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitAtan2( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { + return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type}, + prim_type); +} + StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { - case HloOpcode::kExp: - return EmitLibdeviceMathCall("__nv_exp", {operand_value}, {input_type}, - output_type); case HloOpcode::kFloor: return EmitLibdeviceMathCall("__nv_floor", {operand_value}, {input_type}, output_type); case HloOpcode::kCeil: return EmitLibdeviceMathCall("__nv_ceil", {operand_value}, {input_type}, output_type); - case HloOpcode::kLog: - return EmitLibdeviceMathCall("__nv_log", {operand_value}, {input_type}, - output_type); - case HloOpcode::kCos: - return EmitLibdeviceMathCall("__nv_cos", {operand_value}, {input_type}, - output_type); - case HloOpcode::kSin: - return EmitLibdeviceMathCall("__nv_sin", {operand_value}, {input_type}, - output_type); case HloOpcode::kTanh: return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, output_type); @@ -230,224 +247,6 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( } } -StatusOr GpuElementalIrEmitter::EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - TF_RET_CHECK(primitive_util::IsComplexType(input_type)); - PrimitiveType component_type = - primitive_util::ComplexComponentType(input_type); - switch (op->opcode()) { - case HloOpcode::kPower: { - // (a+bi)^(c+di) = - // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), - // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) - auto a = EmitExtractReal(lhs_value); - auto b = EmitExtractImag(lhs_value); - auto c = EmitExtractReal(rhs_value); - auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); - auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = ir_builder_->CreateFMul(one_half, c); - - TF_ASSIGN_OR_RETURN( - auto aa_p_bb_to_half_c, - EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c}, - {component_type, component_type}, - component_type)); - auto neg_d = ir_builder_->CreateFNeg(d); - TF_ASSIGN_OR_RETURN( - auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a}, - {component_type, component_type}, - component_type)); - auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); - TF_ASSIGN_OR_RETURN( - auto e_to_neg_d_arg_lhs, - EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type}, - component_type)); - auto coeff = - ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); - TF_ASSIGN_OR_RETURN( - auto ln_aa_p_bb, - EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type}, - component_type)); - auto half_d = ir_builder_->CreateFMul(one_half, d); - auto q = - ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), - ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); - TF_ASSIGN_OR_RETURN( - auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type}, - component_type)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), - ir_builder_->CreateFMul(coeff, sin_q)); - } - default: - return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value); - } -} - -StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - PrimitiveType component_type = - primitive_util::IsComplexType(input_type) - ? primitive_util::ComplexComponentType(input_type) - : input_type; - - switch (op->opcode()) { - case HloOpcode::kLog: { - // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - llvm::Type* llvm_ty = a->getType(); - auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); - TF_ASSIGN_OR_RETURN( - auto log_sum_sq, - EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a}, - {component_type, component_type}, - component_type)); - auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex( - op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); - } - case HloOpcode::kExp: { - // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( - auto exp_a, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, - component_type)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); - } - case HloOpcode::kCos: { - // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = EmitExtractReal(operand_value); - auto llvm_ty = a->getType(); - TF_ASSIGN_OR_RETURN( - auto exp_b, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, - component_type)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); - } - - case HloOpcode::kSin: { - // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = EmitExtractReal(operand_value); - auto llvm_ty = a->getType(); - TF_ASSIGN_OR_RETURN( - auto exp_b, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, - component_type)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); - } - case HloOpcode::kTanh: { - /* - tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) - e^(a+bi) = e^a*(cos(b)+sin(b)i) - so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / - (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) - cos(b)=cos(-b), sin(-b)=-sin(b) - so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / - (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) - =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / - (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) - =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / - (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) - This is a complex division, so we can multiply by denom_conj/denom_conj - =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * - (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / - ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) - =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + - i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / - ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) - */ - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( - auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, - component_type)); - auto exp_neg_a = ir_builder_->CreateFDiv( - llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(exp_a, exp_a), - ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); - auto real_num = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); - auto exp_a_plus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); - auto exp_a_minus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = ir_builder_->CreateFMul( - cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, - exp_a_minus_exp_neg_a_sq)); - auto denom = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), - ir_builder_->CreateFDiv(imag_num, denom)); - } - default: - return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); - } -} - llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( const string& callee_name, tensorflow::gtl::ArraySlice operands, diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 6a537d015209bc507af36b13eeb5d69ce58d8fea..77d4569b1e8e398005e8f517ff086a77aedd382d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -54,20 +54,31 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; - StatusOr EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const override; - StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitLog(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitSin(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitCos(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitExp(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; + + StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; + llvm::Value* EmitThreadId() const override; private: diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..66931bdc8b1030b2b2e7731ce6327c1e908d4ee6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -0,0 +1,234 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +FftScratchAllocator::FftScratchAllocator( + int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + +FftScratchAllocator::~FftScratchAllocator() { + for (auto& allocated_buffer : allocated_buffers_) { + if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) + .ok()) { + // The program can still continue with failed deallocation. + LOG(ERROR) << "Failed to deallocate the allocated buffer: " + << allocated_buffer.opaque(); + } + } +} + +int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { + constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. + return kFftScratchSize; +} + +se::port::StatusOr> FftScratchAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + auto status_or_memory = + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false); + if (!status_or_memory.ok()) { + return tensorflow::errors::ResourceExhausted( + "Failed to allocate %lld bytes on device %d.", byte_size, + device_ordinal_); + } + se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); + allocated_buffers_.push_back(allocated_buffer); + total_allocated_bytes_ += byte_size; + return se::DeviceMemory(allocated_buffer); +} + +namespace { + +se::fft::Type FftTypeToSeType(FftType type) { + switch (type) { + case FftType::FFT: + return se::fft::Type::kC2CForward; + case FftType::IFFT: + return se::fft::Type::kC2CInverse; + case FftType::IRFFT: + return se::fft::Type::kC2R; + case FftType::RFFT: + return se::fft::Type::kR2C; + default: + LOG(FATAL) << "unsupported fft type"; + } +} + +string FftTypeToString(se::fft::Type type) { + switch (type) { + case se::fft::Type::kC2CForward: + return "FFT"; + case se::fft::Type::kC2CInverse: + return "IFFT"; + case se::fft::Type::kC2R: + return "IRFFT"; + case se::fft::Type::kR2C: + return "RFFT"; + default: + LOG(FATAL) << "unknown fft type"; + } +} + +} // namespace + +FftThunk::FftThunk(FftType fft_type, + tensorflow::gtl::ArraySlice fft_length, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& output_buffer, + const Shape& input_shape, const Shape& output_shape, + const HloInstruction* hlo) + : Thunk(Kind::kFft, hlo), + fft_type_(FftTypeToSeType(fft_type)), + fft_length_(fft_length.begin(), fft_length.end()), + scale_factor_(1.0f), + input_buffer_(input_buffer), + output_buffer_(output_buffer), + input_shape_(input_shape), + output_shape_(output_shape) {} + +tensorflow::Status FftThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); + VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); + VLOG(3) << "Output shape: " + << ShapeUtil::HumanStringWithLayout(output_shape_); + + FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(), + buffer_allocations.memory_allocator()); + + if (fft_plan_ == nullptr) { + const int64 fft_rank = fft_length_.size(); + CHECK_LE(fft_rank, 3); + int batch_size = 1; + for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) { + batch_size *= input_shape_.dimensions(i); + } + uint64 fft_length[3]; + uint64 input_embed[3]; + const uint64 input_stride = 1; + uint64 input_distance = 1; + uint64 output_embed[3]; + const uint64 output_stride = 1; + uint64 output_distance = 1; + + for (int i = 0; i < fft_rank; ++i) { + auto dim_offset = input_shape_.dimensions_size() - fft_rank + i; + fft_length[i] = static_cast(fft_length_[i]); + input_embed[i] = input_shape_.dimensions(dim_offset); + input_distance *= input_shape_.dimensions(dim_offset); + output_embed[i] = output_shape_.dimensions(dim_offset); + output_distance *= output_shape_.dimensions(dim_offset); + } + + constexpr bool kInPlaceFft = false; + fft_plan_ = + stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator( + stream, fft_rank, fft_length, input_embed, input_stride, + input_distance, output_embed, output_stride, output_distance, + fft_type_, kInPlaceFft, batch_size, &scratch_allocator); + scale_factor_ = 1.0f / output_distance; + } else { + stream->parent()->AsFft()->UpdatePlanWithScratchAllocator( + stream, fft_plan_.get(), &scratch_allocator); + } + + bool launch_ok; + switch (fft_type_) { + case se::fft::Type::kC2CForward: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } + case se::fft::Type::kC2CInverse: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = + stream + ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + complex64(scale_factor_), &output_data, 1) + .ok(); + } + break; + } + case se::fft::Type::kR2C: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } + case se::fft::Type::kC2R: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = stream + ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + scale_factor_, &output_data, 1) + .ok(); + } + break; + } + default: + LOG(FATAL) << "unsupported fft type"; + } + if (launch_ok) { + return tensorflow::Status::OK(); + } + return InternalError("Unable to launch fft for thunk %p with type %s", this, + FftTypeToString(fft_type_).c_str()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..52fb8c376d7acea0f15aaa865c23fa2382717338 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A one-time scratch allocator for FFT. The scratch buffers allocated are +// released on destruction. +// +// Not thread-safe in that AllocateBytes, destructor are not locked. +class FftScratchAllocator : public perftools::gputools::ScratchAllocator { + public: + FftScratchAllocator(int device_ordinal, + DeviceMemoryAllocator* memory_allocator); + + ~FftScratchAllocator() override; + + int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; + + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + perftools::gputools::port::StatusOr> + AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +// This class stores everything that StreamExecutor needs to launch an FFT. +// It is generated by IrEmitter. +// +// This is thread-compatible. +class FftThunk : public Thunk { + public: + // Constructs a thunk for launching an FFT on a stream. + // Semantics of null hlo_instruction argument are as in Thunk. + FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice fft_length, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& output_buffer, + const Shape& input_shape, const Shape& output_shape, + const HloInstruction* hlo); + + FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_ + FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ + + // Does the FFT for the thunk on "stream". + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const perftools::gputools::fft::Type fft_type_; + const std::vector fft_length_; + + float scale_factor_; + + std::unique_ptr fft_plan_; + + const BufferAllocation::Slice input_buffer_; + const BufferAllocation::Slice output_buffer_; + + const Shape input_shape_; + const Shape output_shape_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index e784046450ed1cca088770c65c786e80adda869f..8e3aebbc12b5e6d746700956b9743bc94db50167 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -264,9 +264,9 @@ tensorflow::Status GemmThunk::ExecuteOnStream( auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape, bool transpose) -> MatrixDescriptor { - bool is_row_major = shape.layout().minor_to_major(0) != 0; - bool layout_mismatch = shape.layout().minor_to_major(0) != - output_shape_.layout().minor_to_major(0); + bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0; + bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) != + LayoutUtil::Minor(output_shape_.layout(), 0); return MatrixDescriptor(data, transpose ^ layout_mismatch, shape.dimensions(is_row_major), shape.dimensions(!is_row_major)); @@ -320,7 +320,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( }; bool launch_ok; - if (output_shape_.layout().minor_to_major(0) == 0) { + if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) { launch_ok = launch( lhs_descriptor, rhs_descriptor, MatrixDescriptor(output_data, false, output_num_rows, output_num_cols), diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 983cb872924f22be0dfad8aa9ad86f233b909c46..8c6a1f51a8a09ef78950dfe7e89994a3fe247f49 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -52,6 +52,15 @@ class GemmThunk : public Thunk { const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; + // Returns true if we'll perform autotuning if run on the given stream. If + // so, we want the GPU to be quiescent during autotuning, so as not to + // introduce noise in our results. + bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* stream) override { + return autotune_results_.count( + stream->parent()->GetDeviceDescription().name()) != 0; + } + private: const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index fcd73fd37a2d9ae3c24b56970e3e992da5944682..89acac2c3ff77a93b6cf3b871a130dcd7edecf30 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -18,30 +18,36 @@ limitations under the License. #include #include #include +#include // NOLINT(build/c++11): only using std::call_once, not mutex. #include #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" @@ -64,6 +70,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -74,9 +81,11 @@ limitations under the License. #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" namespace se = ::perftools::gputools; @@ -90,14 +99,6 @@ namespace gpu { namespace { using tensorflow::port::Tracing; -using tensorflow::strings::StrCat; - -// Any address of a variable residing in global memory or returned by one of the -// memory allocation routines from the driver or runtime API is always aligned -// to at least 256 bytes. -// -// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses -constexpr int64 kMemoryAlignment = 256; // Returns the directory containing nvvm libdevice files. config_cuda_data_dir // should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the @@ -125,31 +126,39 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule( - HloModule* hlo_module, - const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { +tensorflow::Status OptimizeHloModule(HloModule* hlo_module) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker(shape_size_function); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); - + pipeline.AddPass(); { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(shape_size_function); + pass.AddInvariantChecker(); - // TODO(b/62764704): Do not rewrite on GPU, use cuDNN's BatchNorm APIs - // instead. - pass.AddPass( + // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls + // where possible. Not every batchnorm op can be implemented as a call to + // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs. + if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) { + pass.AddPass(); + } + pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, /*use_fusion=*/false); + + // BatchNormExpander can create zero-sized ops, so zero-sized HLO + // elimination has to come after that pass. + pipeline.AddPass(); + pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -173,14 +182,14 @@ tensorflow::Status OptimizeHloModule( } { HloPassFix fusion("fusion"); - fusion.AddInvariantChecker(shape_size_function); + fusion.AddInvariantChecker(); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); - reduce_pipeline.AddInvariantChecker(shape_size_function); + reduce_pipeline.AddInvariantChecker(); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -198,16 +207,14 @@ tensorflow::Status OptimizeHloModule( // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting( - HloModule* hlo_module, - const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { +tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(shape_size_function); + pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout()); @@ -229,6 +236,93 @@ tensorflow::Status PrepareHloModuleForIrEmitting( return pipeline.Run(hlo_module).status(); } +// Prints a warning if the ptxas at ptxas_path has known bugs. +// +// Only prints a warning the first time it's called for a particular value of +// ptxas_path. +void WarnIfBadPtxasVersion(const string& ptxas_path) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static std::unordered_set* seen_ptxas_paths GUARDED_BY(mu) = + new std::unordered_set(); + + tensorflow::mutex_lock lock(mu); + if (!seen_ptxas_paths->insert(ptxas_path).second) { + // Already checked this ptx binary, nothing to do. + return; + } + + tensorflow::SubProcess ptxas; + ptxas.SetProgram(ptxas_path, {ptxas_path, "--version"}); + ptxas.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); + if (!ptxas.Start()) { + LOG(WARNING) << "Couldn't invoke " << ptxas_path << " --version"; + return; + } + + string out; + int exit_code = ptxas.Communicate(/*stdin_input=*/nullptr, &out, + /*stderr_output=*/nullptr); + if (exit_code != 0) { + LOG(WARNING) << "Running " << ptxas_path << " --version returned " + << exit_code; + return; + } + + int64 vmaj, vmin, vdot; + string vmaj_str, vmin_str, vdot_str; + if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, + &vmin_str, &vdot_str) || + !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) || + !tensorflow::strings::safe_strto64(vmin_str, &vmin) || + !tensorflow::strings::safe_strto64(vdot_str, &vdot)) { + LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path + << " --version:\n" + << out; + return; + } + + // ptxas 9.0 before 9.0.276 miscompiles some address calculations with large + // offsets (e.g. "load ptr + large_constant"), b/70245379. + if (vmaj == 9 && vmin == 0 && vdot < 276) { + LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." + << vmin << "." << vdot + << ", which is in range [9.0.0, 9.0.276). These versions are " + "known to miscompile XLA code, leading to incorrect " + "results or invalid-address errors."; + } +} + +// Prints a warning if the ptx->sass JIT in the driver has known bugs. +// +// Using such a driver only a problem if we fail to use ptxas to compile our ptx +// and have to use the driver instead, so you should only call this function if +// we're going to use the driver JIT. +// +// Only prints a warning the first time it's called. +void WarnIfBadDriverJITVersion() { + static std::once_flag run_once; + std::call_once(run_once, [] { + auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion(); + if (!version_or_status.ok()) { + LOG(WARNING) << "Couldn't read CUDA driver version."; + return; + } + se::cuda::DriverVersion version = version_or_status.ValueOrDie(); + + // The driver JIT in 384 before 384.108 miscompiles some address + // calculations with large offsets (e.g. "load ptr + large_constant"), + // b/70245379. + if (std::get<0>(version) == 384 && std::get<1>(version) < 108) { + LOG(WARNING) + << "*** WARNING *** Invoking the PTX->SASS JIT from driver version " + << se::cuda::DriverVersionToString(version) + << ", which is in range [384.0.0, 384.108.0). These versions are " + "known to miscompile XLA code, leading to incorrect results or " + "invalid-address errors."; + } + }); +} + // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. StatusOr> CompilePtx(const string& ptx, int cc_major, @@ -240,6 +334,8 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, auto env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); + WarnIfBadPtxasVersion(ptxas_path); + // Write ptx into a temporary file. string ptx_path; if (!env->LocalTempFilename(&ptx_path)) { @@ -263,8 +359,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); }); tensorflow::SubProcess ptxas_info_dumper; - std::vector ptxas_args = {ptxas_path, ptx_path, "-o", cubin_path, - StrCat("-arch=sm_", cc_major, cc_minor)}; + std::vector ptxas_args = { + ptxas_path, ptx_path, "-o", cubin_path, + tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)}; if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } @@ -294,14 +391,15 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, } // namespace GpuCompiler::GpuCompiler() - : pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} + : pointer_size_(llvm::DataLayout(kDataLayout) + .getPointerSize(0 /* default address space */)) {} StatusOr> GpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* /*stream_exec*/) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); Tracing::TraceMe annotation("HLO Transforms", module->name(), /*is_expensive=*/true); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), ShapeSizeBytesFunction())); + TF_RETURN_IF_ERROR(OptimizeHloModule(module.get())); return std::move(module); } @@ -311,8 +409,7 @@ StatusOr> GpuCompiler::RunBackend( TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR( - PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); llvm::LLVMContext llvm_context; std::string buffer; @@ -343,8 +440,9 @@ StatusOr> GpuCompiler::RunBackend( TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), [](LogicalBuffer::Color) { - return kMemoryAlignment; + BufferSizeBytesFunction(), + /*color_alignment=*/[](LogicalBuffer::Color) { + return kCudaMallocAlignBytes; })); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. @@ -393,6 +491,20 @@ StatusOr> GpuCompiler::RunBackend( /*optimized=*/false)); } + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. + TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " + "Rerun with --xla_dump_ir_to to get the IR. "; + } + string libdevice_dir; { tensorflow::mutex_lock lock(mutex_); @@ -443,7 +555,7 @@ StatusOr> GpuCompiler::RunBackend( // Write PTX to IR dump directory, if IR dumping was requested. if (!ir_dump_directory.empty()) { const string ptx_outfile = tensorflow::io::JoinPath( - ir_dump_directory, StrCat(module->name(), ".ptx")); + ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx")); auto status = [&] { auto* env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); @@ -470,6 +582,7 @@ StatusOr> GpuCompiler::RunBackend( if (module->config().hlo_profiling_enabled()) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); profile_index_map = MakeUnique(*module); profile_printer = CreateHloProfilePrinter(*profile_index_map, cost_analysis); @@ -541,6 +654,10 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, "GPU driver compile the ptx. " << maybe_cubin.status(); } + + // We're going to use the driver to JIT our PTX->SASS, so warn if + // the JIT in the driver has known bugs. + WarnIfBadDriverJITVersion(); } } cache_value->compilation_done = true; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa360c7f73de2f0f9cf59c22b552b8e60ddb3a87 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" + +namespace xla { +namespace gpu { + +// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses +const int64 kCudaMallocAlignBytes = 256; + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h new file mode 100644 index 0000000000000000000000000000000000000000..572c85628278752f924b90dbb7134c5fc8fb9740 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ + +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace gpu { + +// Minimum alignment of cudaMalloc. We require that buffers created by our +// DeviceMemoryAllocator, and all input/output buffers, have this alignment. +extern const int64 kCudaMallocAlignBytes; + +} // namespace gpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 33d739b79d3664fec3586bbc924b7fa2e10d3256..e67087d822e2f3367c48b08be66f5f60791be638 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -55,20 +55,33 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { // in IR. for (HloInstruction* hlo : module->entry_computation()->MakeInstructionPostOrder()) { - if (ImplementedAsLibraryCall(*hlo)) { + // Inserts a copy of hlo->operand(n) if it's a constant. + auto copy_operand_if_constant = [&](int64 n) -> Status { + HloInstruction* operand = hlo->mutable_operand(n); + TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); + const auto& values = dataflow->GetValueSet(operand).values(); + if (std::any_of(values.begin(), values.end(), [](const HloValue* value) { + return value->defining_instruction()->opcode() == + HloOpcode::kConstant; + })) { + TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(n, copy)); + changed = true; + } + return Status::OK(); + }; + + if (IsCustomCallToDnnBatchNorm(*hlo)) { + // The epsilon and feature_index operands to a CUDNN batchnorm op don't + // need to be materialized in memory -- in fact, they must be constants. + // These are the last two operands of all three batchnorm ops. + for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } + } else if (ImplementedAsLibraryCall(*hlo)) { + // For all other library calls, materialize all the operands into memory. for (int64 i = 0; i < hlo->operand_count(); ++i) { - HloInstruction* operand = hlo->mutable_operand(i); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); - const auto& values = dataflow->GetValueSet(operand).values(); - if (std::any_of(values.begin(), values.end(), - [](const HloValue* value) { - return value->defining_instruction()->opcode() == - HloOpcode::kConstant; - })) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy)); - changed = true; - } + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 0fd85e4fb057f144df93d53485570d67c66af0d4..51d164cdf427f9513bc340e090832a9b064b999c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -66,10 +66,12 @@ class HloExecutionProfiler { // If profiling is enabled, sets the total cycle count on the profile from the // execution timer. - ~HloExecutionProfiler() { + void FinishExecution() { + CHECK(!finished_execution_) << "Call FinishExecution only once!"; + finished_execution_ = true; if (do_profile_) { stream_->ThenStopTimer(execution_timer_.get()); - stream_->BlockHostUntilDone(); + stream_->BlockHostUntilDone().IgnoreError(); profile_->set_total_cycles_executed( *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_); } @@ -87,7 +89,7 @@ class HloExecutionProfiler { void FinishOperation(const HloInstruction* hlo_instruction) { if (do_profile_) { stream_->ThenStopTimer(per_op_timer_.get()); - stream_->BlockHostUntilDone(); + stream_->BlockHostUntilDone().IgnoreError(); profile_->SetCyclesTakenBy( hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); } @@ -101,6 +103,7 @@ class HloExecutionProfiler { const HloComputation* computation_; std::unique_ptr execution_timer_; std::unique_ptr per_op_timer_; + bool finished_execution_ = false; }; } // namespace @@ -143,9 +146,12 @@ Status GpuExecutable::ExecuteThunks( if (do_profile) { LOG(WARNING) << "PROFILING: profiling is enabled"; } + HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, hlo_module_->entry_computation()); + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + // Stream 0 indicates `main_stream` and substreams start from stream 1. std::vector::SmartPtr> sub_streams; while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { @@ -155,6 +161,9 @@ Status GpuExecutable::ExecuteThunks( run_options->BorrowStream(main_stream->parent()->device_ordinal())); } + // The next event enqueued on stream N must not run until the thunk at + // last_blocking_thunk_for_stream[N] completes. + std::map last_blocking_thunk_for_stream; std::map> thunk_to_finish_event; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { TF_RETURN_IF_ERROR(thunk->Initialize(*this)); @@ -167,15 +176,41 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } + if (last_blocking_thunk_for_stream.count(stream_no)) { + stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, + last_blocking_thunk_for_stream[stream_no]) + .get()); + last_blocking_thunk_for_stream.erase(stream_no); + } + + // If this thunk requests it, wait for all currently-executing thunks to + // finish. This is useful e.g. if the thunk is about to perform autotuning. + if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { + TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); + last_blocking_thunk_for_stream.clear(); + } + profiler.StartOperation(); VLOG(2) << "Executing the thunk for " - << thunk->hlo_instruction()->ToString(); + << thunk->hlo_instruction()->ToString() << " on stream " + << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); - if (thunk_schedule_->Depended(thunk)) { + if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); + + if (thunk->ShouldBlockFutureThunks()) { + // Set last_blocking_thunk_for_stream on all streams other than this one + // so that all other streams will wait for this thunk to complete before + // executing any events that occur later in the total order. + for (int32 i = 0; i < sub_streams.size() + 1; ++i) { + if (i != stream_no) { + last_blocking_thunk_for_stream[i] = thunk; + } + } + } } profiler.FinishOperation(thunk->hlo_instruction()); } @@ -184,90 +219,32 @@ Status GpuExecutable::ExecuteThunks( // Make sure kernels are completed before deallocating temporary buffers. // TODO(b/30100571): we could potentially postpone deallocating the temp // buffers until a different computation is executed. - if (block_host_until_done && !main_stream->BlockHostUntilDone()) { - return InternalError("Failed to complete all kernels launched on stream %p", - main_stream); + if (block_host_until_done) { + Status block_status = main_stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError( + "Failed to complete all kernels launched on stream %p: %s", + main_stream, block_status.error_message().c_str()); + } } - return Status::OK(); -} + profiler.FinishExecution(); + uint64 end_micros = tensorflow::Env::Default()->NowMicros(); -StatusOr GpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + { + tensorflow::mutex_lock lock(mutex_); + const double nanoseconds = (end_micros - start_micros) * 1000.0; + execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - BufferAllocations::Builder buffer_allocations_builder; - for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); - ++i) { - const BufferAllocation& allocation = assignment_->GetAllocation(i); - if (allocation.is_entry_computation_parameter()) { - buffer_allocations_builder.RegisterBuffer( - i, arguments[allocation.parameter_number()]); + // If hlo profiling was disabled then the cycle count is left empty. + if (do_profile) { + execution_profile_.set_compute_cycle_count( + hlo_execution_profile->total_cycles_executed( + *module().entry_computation())); } } - se::StreamExecutor* executor = stream->parent(); - TF_ASSIGN_OR_RETURN( - auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), - memory_allocator)); - bool block_host_until_done = - !memory_allocator->AllowsAsynchronousDeallocation(); - TF_RETURN_IF_ERROR(ExecuteThunks(run_options, *buffer_allocations, - block_host_until_done, - hlo_execution_profile)); - - HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice output_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase output_buffer_address = - buffer_allocations->GetDeviceAddress(output_slice.index()); - - if (ShapeUtil::IsTuple(root->shape())) { - std::set referred_by_output; - if (GetRootPointsToSet().IsAmbiguous()) { - // The points-to set of the root is ambiguous so we need to examine the - // result data to determine which buffers are contained in the result. - TF_ASSIGN_OR_RETURN( - TransferManager * transfer_manager, - TransferManager::GetForPlatform(executor->platform())); - TF_ASSIGN_OR_RETURN(referred_by_output, - transfer_manager->GatherBufferPointersFromTuple( - executor, output_buffer_address, root->shape())); - } else { - // The points-to set of the root is unambiguous so it's known statically - // which buffers are in the result. Gather these buffers using the root's - // points-to set. - TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElementWithStatus( - [&referred_by_output, &buffer_allocations, this]( - const ShapeIndex& /*index*/, - const PointsToSet::BufferList& buffers) { - // The points to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction produced - // the array at this element. - CHECK_EQ(1, buffers.size()); - HloInstruction* hlo = buffers[0]->instruction(); - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice(hlo, buffers[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - referred_by_output.insert( - buffer_allocations->GetDeviceAddress(slice.index())); - return Status::OK(); - })); - } - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown(referred_by_output, *assignment_)); - } else { - // If the computation result is not a tuple, we can delete all temporary - // buffers that are not the output. - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown({output_buffer_address}, *assignment_)); - } - return output_buffer_address; + return Status::OK(); } StatusOr> GpuExecutable::ExecuteOnStream( @@ -287,7 +264,7 @@ StatusOr> GpuExecutable::ExecuteOnStream( if (allocation.is_entry_computation_parameter()) { auto param_no = allocation.parameter_number(); buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->buffer(/*index=*/{})); + i, arguments[param_no]->root_buffer()); } } se::StreamExecutor* executor = run_options->stream()->parent(); @@ -305,50 +282,46 @@ StatusOr> GpuExecutable::ExecuteOnStream( HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); auto device_ordinal = executor->device_ordinal(); auto shaped_buffer = MakeUnique( - root->shape(), executor->platform(), device_ordinal); + root->shape(), root->shape(), executor->platform(), device_ordinal); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer. std::set buffers_in_result; - TF_RETURN_IF_ERROR( - shaped_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = this->GetRootPointsToSet().element(index); - // The points-to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction - // produced the array at this element. - CHECK_EQ(1, sources.size()); - auto src_hlo = sources[0]->instruction(); - - VLOG(4) << "Looking at: " << sources[0]; - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src_hlo, sources[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - perftools::gputools::DeviceMemoryBase src_base = - buffer_allocations->GetDeviceAddress(slice.index()); - CHECK(!src_base.is_null() || src_base.size() == 0); - shaped_buffer->mutable_buffers()->push_back(src_base); - *buffer_entry = shaped_buffer->mutable_buffers()->size() - 1; - - buffers_in_result.insert(src_base); - return Status::OK(); - })); + TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( + const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + // The points-to set is unambiguous so the set should be a + // singleton. That is, we know exactly which instruction + // produced the array at this element. + CHECK_EQ(1, sources.size()); + auto src_hlo = sources[0]->instruction(); + + VLOG(4) << "Looking at: " << sources[0]; + + // The source instruction should have a non-parameter buffer + // assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src_hlo, sources[0]->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + perftools::gputools::DeviceMemoryBase src_base = + buffer_allocations->GetDeviceAddress(slice.index()); + CHECK(!src_base.is_null() || src_base.size() == 0); + *device_memory = src_base; + buffers_in_result.insert(src_base); + return Status::OK(); + })); TF_RETURN_IF_ERROR( buffer_allocations->TearDown(buffers_in_result, *assignment_)); return std::move(shaped_buffer); } -StatusOr GpuExecutable::ExecuteAsyncOnStream( +StatusOr> GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on GPU."); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index e7307e07c0b5608e31f15597d31d11c50f81c6d5..00da64dfade8ddb0694c0ee7ac158c9f2e15a508 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -72,24 +72,16 @@ class GpuExecutable : public Executable { // empty, in which case compilation is left up to the GPU driver. const std::vector& cubin() const { return cubin_; } - // Both overloads of ExecuteOnStream will fail if the compute capability of - // the stream doesn't match the compute capability passed to this object's - // constructor. - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - + // ExecuteOnStream will fail if the compute capability of the stream doesn't + // match the compute capability passed to this object's constructor. StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; const Status EqualOrFail(const Executable& executable) { // TODO(b/62952745) Implement equality test on GPU executable. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc new file mode 100644 index 0000000000000000000000000000000000000000..4944c41f7d8dc7a78a3cd094aee4d7087c74857e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr GpuHloSupportChecker::Run(HloModule* module) { + for (auto* computation : module->computations()) { + for (const auto& instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instruction->shape(), + [&instruction](const Shape& subshape, const ShapeIndex&) { + if (LayoutUtil::IsSparseArray(subshape)) { + return xla::Unimplemented( + "GPU backend does not support HLO instruction %s with shape " + "containing a sparse layout: %s", + instruction->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(instruction->shape()) + .c_str()); + } + return Status::OK(); + })); + } + } + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h new file mode 100644 index 0000000000000000000000000000000000000000..d9550f81b591ead3f6e8d3de4f62896ee04d2f82 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// his pass should run early in the HLO pipeline and checks for HLO constructs +// which are not supported by the GPU backend and cannot be removed via HLO +// transformations (eg, sparse layouts). +class GpuHloSupportChecker : public HloPassInterface { + public: + GpuHloSupportChecker() = default; + ~GpuHloSupportChecker() override = default; + + tensorflow::StringPiece name() const override { + return "gpu_hlo_support_checker"; + } + + // Note: always returns false (no instructions are ever modified by this + // pass). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0a4089df4c954cafcbe241189ee79a0995683513 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +class GpuHloSupportCheckerTest : public HloTestBase { + protected: + GpuHloSupportChecker& checker() { return checker_; } + + private: + GpuHloSupportChecker checker_; +}; + +TEST_F(GpuHloSupportCheckerTest, Add) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(checker().Run(module.get()).status()); +} + +TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { + HloComputation::Builder builder(TestName()); + const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, sparse_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, sparse_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + sparse_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Status status = checker().Run(module.get()).status(); + ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_THAT(status.error_message(), + HasSubstr("GPU backend does not support")); + EXPECT_THAT(status.error_message(), + HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc similarity index 57% rename from tensorflow/compiler/xla/service/gpu/layout_assignment.cc rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index d475c4171b56ceedf5fdbda8b4d6221af844261c..58915f1f62f0c0f320443058a798333c498ffe47 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include @@ -149,5 +149,106 @@ Status GpuLayoutAssignment::AddBackendConstraints( return Status::OK(); } +bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( + const HloInstruction* instruction) { + // Inputs to cudnn batchnorm custom calls don't need the major-first layout + // (i.e. {n, n-1, ...0}) -- we can handle any layout. + return !IsCustomCallToDnnBatchNorm(*instruction); +} + +Status GpuLayoutAssignment::PropagateOperandConstraint( + const OperandLayoutConstraint& layout_constraint, + LayoutConstraints* constraints) { + const HloInstruction* instruction = layout_constraint.instruction(); + + // cudnn batchnorm forward inference's result must have the same layout as its + // operand 0. + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == + kCudnnBatchNormForwardInferenceCallTarget && + layout_constraint.operand_no() == 0) { + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + layout_constraint.shape_layout().shape(), instruction)); + } + + // cudnn batchnorm forward training returns a tuple {output, mean, + // inverse-stddev}. mean and inverse-stddev are rank 1 and so have only one + // possible layout, but output is not (necessarily) rank 1, and, like in + // batchnorm forward inference, must have the same layout as operand 0. + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == + kCudnnBatchNormForwardTrainingCallTarget && + layout_constraint.operand_no() == 0) { + TF_ASSIGN_OR_RETURN(const LogicalBuffer* out_buf, + constraints->points_to_analysis().GetBufferDefinedAt( + instruction, /*index=*/{0})); + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + layout_constraint.shape_layout().layout(), *out_buf)); + } + + // Like forward training, cudnn batchnorm backward returns a tuple {output, + // mean, inverse-stddev}, and its operand 0 and 'output' must have the same + // layout. In addition, its operand 0 and operand 4 -- the 'operand' and + // 'grad_output' parameters -- must have the same layout. + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == kCudnnBatchNormBackwardCallTarget && + (layout_constraint.operand_no() == 0 || + layout_constraint.operand_no() == 4)) { + TF_ASSIGN_OR_RETURN(const LogicalBuffer* out_buf, + constraints->points_to_analysis().GetBufferDefinedAt( + instruction, /*index=*/{0})); + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + layout_constraint.shape_layout().layout(), *out_buf)); + + int64 operand_to_set = layout_constraint.operand_no() == 0 ? 4 : 0; + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + layout_constraint.shape_layout().shape(), instruction, operand_to_set)); + } + + return LayoutAssignment::PropagateOperandConstraint(layout_constraint, + constraints); +} + +Status GpuLayoutAssignment::PropagateBufferConstraint( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) { + const LogicalBuffer& buf = buffer_constraint.buffer(); + const HloInstruction* instruction = buf.instruction(); + + Shape shape_with_layout = buf.shape(); + *shape_with_layout.mutable_layout() = buffer_constraint.layout(); + + // Propagate output constraints to the operands of cudnn batchnorm ops. This + // is the same as PropagateOperandConstraint, just in the other direction. We + // need to both to fulfill our contract to LayoutAssignment. + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == + kCudnnBatchNormForwardInferenceCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + shape_with_layout, instruction, /*operand_no=*/0)); + } + + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == + kCudnnBatchNormForwardTrainingCallTarget && + buf.index() == ShapeIndex({0})) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + shape_with_layout, instruction, /*operand_no=*/0)); + } + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == kCudnnBatchNormBackwardCallTarget && + buf.index() == ShapeIndex({0})) { + // batchnorm backward has two operands, "operand" and "grad_output" whose + // layouts must both match that of the result at tuple-index 0. + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + shape_with_layout, instruction, /*operand_no=*/0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + shape_with_layout, instruction, /*operand_no=*/4)); + } + + return LayoutAssignment::PropagateBufferConstraint(buffer_constraint, + constraints); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h similarity index 70% rename from tensorflow/compiler/xla/service/gpu/layout_assignment.h rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 169041eb85c633cb4f1f679bcea127714828308f..86a3a7111fd79494e469beecf3234f6cec9adb9c 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -33,9 +33,17 @@ class GpuLayoutAssignment : public LayoutAssignment { protected: Status AddBackendConstraints(LayoutConstraints* constraints) override; + Status PropagateOperandConstraint( + const OperandLayoutConstraint& layout_constraint, + LayoutConstraints* constraints) override; + Status PropagateBufferConstraint( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) override; + bool CustomCallRequiresMajorFirstLayout( + const HloInstruction* instruction) override; }; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4c45d2e94aebce5496da94841f6a1ae9015615c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -0,0 +1,328 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +using LayoutAssignmentTest = HloTestBase; + +TEST_F(LayoutAssignmentTest, Elementwise) { + Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); + Shape ashape_in_row_major(ashape); + Shape ashape_in_col_major(ashape); + *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + // Enumerate all possible combinations of layouts. + for (const Shape& lhs_shape_with_layout : + {ashape_in_row_major, ashape_in_col_major}) { + for (const Shape& rhs_shape_with_layout : + {ashape_in_row_major, ashape_in_col_major}) { + for (const Shape& result_shape_with_layout : + {ashape_in_row_major, ashape_in_col_major}) { + // GpuLayoutAssignment should assign the same layout to "add" and its + // two operands. + auto builder = HloComputation::Builder(TestName()); + auto x = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "x")); + auto y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ashape, "y")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(add)); + + ComputationLayout computation_layout( + computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(lhs_shape_with_layout); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(rhs_shape_with_layout); + *computation_layout.mutable_result_layout() = + ShapeLayout(result_shape_with_layout); + + GpuLayoutAssignment layout_assignment(&computation_layout); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + for (const HloInstruction* operand : add->operands()) { + EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(), + operand->shape().layout())); + } + } + } + } +} + +// Returns a list shapes with all the possible layouts of this shape, including +// a shape with no layout. +std::vector AllLayoutsOf(const Shape& s) { + std::vector layout_vec(s.dimensions_size()); + std::iota(layout_vec.begin(), layout_vec.end(), 0); + + std::vector shapes; + shapes.push_back(s); + shapes.back().clear_layout(); + + do { + shapes.push_back(s); + *shapes.back().mutable_layout() = LayoutUtil::MakeLayout(layout_vec); + } while (std::next_permutation(layout_vec.begin(), layout_vec.end())); + + return shapes; +} + +TEST_F(LayoutAssignmentTest, BatchNormInference) { + const int64 kFeatureIndex = 1; + + // The shape of the data operand to BatchNormInference and of the output of + // the BatchNormInference call. + Shape shape = ShapeUtil::MakeShape(F32, {42, 12, 1, 100}); + + // The shape of the scale, offset, mean, and variance inputs to + // BatchNormTraining. These are rank 1, with as many elements are in the + // kFeatureIndex dim of shape. + Shape aux_shape = + ShapeUtil::MakeShape(F32, {shape.dimensions(kFeatureIndex)}); + + for (const Shape& input_shape : AllLayoutsOf(shape)) { + for (const Shape& result_shape : AllLayoutsOf(shape)) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), + ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); + + auto builder = HloComputation::Builder(TestName()); + auto* operand = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "operand")); + auto* scale = builder.AddInstruction( + HloInstruction::CreateParameter(1, aux_shape, "scale")); + auto* offset = builder.AddInstruction( + HloInstruction::CreateParameter(2, aux_shape, "offset")); + auto* mean = builder.AddInstruction( + HloInstruction::CreateParameter(3, aux_shape, "mean")); + auto* variance = builder.AddInstruction( + HloInstruction::CreateParameter(4, aux_shape, "variance")); + + auto* epsilon = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + auto* feature_index = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(kFeatureIndex))); + + auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( + shape, + {operand, scale, offset, mean, variance, epsilon, feature_index}, + kCudnnBatchNormForwardInferenceCallTarget)); + + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(batchnorm)); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + + if (input_shape.has_layout()) { + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(input_shape); + } + + if (result_shape.has_layout()) { + *computation_layout.mutable_result_layout() = ShapeLayout(result_shape); + } + + GpuLayoutAssignment layout_assignment(&computation_layout); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + // The first operand to batchnorm should have the same layout as the + // result. + EXPECT_TRUE(LayoutUtil::Equal(batchnorm->operand(0)->shape().layout(), + batchnorm->shape().layout())) + << batchnorm->ToString(); + } + } +} + +TEST_F(LayoutAssignmentTest, BatchNormTraining) { + const int64 kFeatureIndex = 1; + + // The shape of the data operand to BatchNormTraining. + Shape shape = ShapeUtil::MakeShape(F32, {42, 12, 1, 100}); + + // The shape of the offset and scale inputs to BatchNormTraining. These are + // rank 1, with as many elements are in the kFeatureIndex dim of shape. + Shape offset_scale_shape = + ShapeUtil::MakeShape(F32, {shape.dimensions(kFeatureIndex)}); + + // Shape of the output of our BatchNormTraining op. + Shape batchnorm_shape = ShapeUtil::MakeTupleShape( + {shape, offset_scale_shape, offset_scale_shape}); + + // Enumerate all combinations of shapes. + for (const Shape& input_shape : AllLayoutsOf(shape)) { + for (const Shape& result_shape : AllLayoutsOf(shape)) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), + ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); + + auto builder = HloComputation::Builder(TestName()); + auto* operand = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "operand")); + auto* scale = builder.AddInstruction( + HloInstruction::CreateParameter(1, offset_scale_shape, "scale")); + auto* offset = builder.AddInstruction( + HloInstruction::CreateParameter(2, offset_scale_shape, "offset")); + + auto* epsilon = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + auto* feature_index = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(kFeatureIndex))); + + auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( + batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, + kCudnnBatchNormForwardTrainingCallTarget)); + + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(batchnorm)); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + + if (input_shape.has_layout()) { + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(input_shape); + } + + if (result_shape.has_layout()) { + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {result_shape, offset_scale_shape, offset_scale_shape})); + } + + GpuLayoutAssignment layout_assignment(&computation_layout); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + // The first operand to batchnorm should have the same layout as the + // first element of the result tuple. + EXPECT_TRUE( + LayoutUtil::Equal(batchnorm->operand(0)->shape().layout(), + batchnorm->shape().tuple_shapes(0).layout())) + << batchnorm->ToString(); + } + } +} + +TEST_F(LayoutAssignmentTest, BatchNormGrad) { + const int64 kFeatureIndex = 1; + + // The shape of the data operand to BatchNormTraining. + Shape shape = ShapeUtil::MakeShape(F32, {42, 12, 1, 100}); + + // The shape of the scale, mean, and variance inputs to BatchNormGrad. These + // are rank 1, with as many elements are in the kFeatureIndex dim of shape. + Shape scale_shape = + ShapeUtil::MakeShape(F32, {shape.dimensions(kFeatureIndex)}); + + // Shape of the output of our BatchNormGrad op. + Shape batchnorm_shape = + ShapeUtil::MakeTupleShape({shape, scale_shape, scale_shape}); + + // Enumerate all combinations of shapes plus whether we're constraining param + // 0 or param 4. + for (const Shape& input_shape : AllLayoutsOf(shape)) { + for (const Shape& result_shape : AllLayoutsOf(shape)) { + for (int constrained_param_no : {0, 4}) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), + ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); + + auto builder = HloComputation::Builder(TestName()); + auto* operand = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "operand")); + auto* scale = builder.AddInstruction( + HloInstruction::CreateParameter(1, scale_shape, "scale")); + auto* mean = builder.AddInstruction( + HloInstruction::CreateParameter(2, scale_shape, "mean")); + auto* var = builder.AddInstruction( + HloInstruction::CreateParameter(3, scale_shape, "var")); + auto* grad_offset = builder.AddInstruction( + HloInstruction::CreateParameter(4, shape, "var")); + + auto* epsilon = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + auto* feature_index = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(kFeatureIndex))); + + auto* batchnorm = + builder.AddInstruction(HloInstruction::CreateCustomCall( + batchnorm_shape, + {operand, scale, mean, var, grad_offset, epsilon, + feature_index}, + kCudnnBatchNormBackwardCallTarget)); + + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(batchnorm)); + + ComputationLayout computation_layout( + computation->ComputeProgramShape()); + + if (input_shape.has_layout()) { + *computation_layout.mutable_parameter_layout(constrained_param_no) = + ShapeLayout(input_shape); + } + + if (result_shape.has_layout()) { + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {result_shape, scale_shape, scale_shape})); + } + + GpuLayoutAssignment layout_assignment(&computation_layout); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + // The first and fourth operands to the batchnorm call should have the + // same layout as the first element of the result tuple. + EXPECT_TRUE( + LayoutUtil::Equal(batchnorm->operand(0)->shape().layout(), + batchnorm->shape().tuple_shapes(0).layout())) + << batchnorm->ToString(); + EXPECT_TRUE( + LayoutUtil::Equal(batchnorm->operand(4)->shape().layout(), + batchnorm->shape().tuple_shapes(0).layout())) + << batchnorm->ToString(); + } + } + } +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f0f036f7f381db15b84db85d3efeec5d8141884e..af9897769fda371e47af06c19abce9a06015e094 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -44,7 +44,7 @@ GpuTransferManager::GpuTransferManager() : GenericTransferManager( se::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) - .getPointerSize()) {} + .getPointerSize(0 /* default address space */)) {} Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { @@ -54,7 +54,7 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, if (!ShapeUtil::IsTuple(shape)) { int64 size = GetByteSizeRequirement(shape); - return TransferBufferToInfeed(executor, size, literal.InternalData()); + return TransferBufferToInfeed(executor, size, literal.untyped_data()); } if (ShapeUtil::IsNestedTuple(shape)) { @@ -67,20 +67,21 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, // enqueue the resulting destination device addresses with the // infeed manager. std::vector buffers; - buffers.reserve(literal.tuple_literals_size()); + buffers.reserve(ShapeUtil::TupleElementCount(shape)); auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { for (gpu::InfeedBuffer* b : buffers) { b->Done(); } }); - for (const auto& tuple_element : literal.tuple_literals()) { - const Shape& tuple_element_shape = tuple_element.shape(); + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& tuple_element_shape = + ShapeUtil::GetTupleElementShape(shape, i); int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); TF_ASSIGN_OR_RETURN( gpu::InfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, tuple_element_size, - tuple_element.InternalData())); + literal.untyped_data({i}))); buffers.push_back(buffer); } @@ -105,12 +106,13 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( // infeed requests, blocking on the stream might be // heavy-handed. Figure out if finer-grained acknowledgement is // possible. - if (!stream->BlockHostUntilDone()) { + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { for (gpu::InfeedBuffer* b : buffers) { b->Done(); } - return InternalError("Failed to complete data transfer on stream %p", - stream); + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); } infeed_manager->EnqueueBuffers(buffers); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index e33e904692ca5ad41e17d2e165dbb40b6bd4aa33..2ac95ceb692447c7ac6dbbcd8b9a38876f7a77b6 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -30,9 +30,8 @@ InfeedThunk::InfeedThunk( tuple_element_buffers.end()), destination_buffer_(destination_buffer) {} -tensorflow::Status InfeedThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { +Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { VLOG(2) << "Infeeding to GPU "; perftools::gputools::DeviceMemoryBase destination_address = @@ -66,15 +65,16 @@ tensorflow::Status InfeedThunk::ExecuteOnStream( buffer->length()); } - if (!stream->BlockHostUntilDone()) { - return InternalError("Failed to complete data transfer on stream %p", - stream); + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); } infeed_manager->ReleaseBuffers(infeed_buffers); VLOG(2) << "Infeeding to GPU complete"; - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 371d71f9dbdd21cb5f36cc3108c8f398a4a91c29..86918705fa0305217f11753e383200c7bd71474b 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -43,9 +43,8 @@ class InfeedThunk : public Thunk { InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 658fd05cd4b63c923d21b4a1de16468c0aeec65d..76566a9e3dbbc936ff90fe3f440ede14bf4e5233 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -110,6 +110,10 @@ bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { return false; } + if (window_util::HasWindowReversal(hlo.window())) { + return false; + } + return true; } @@ -123,8 +127,26 @@ bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { return false; } +const char* const kCudnnBatchNormForwardInferenceCallTarget = + "__cudnn$batchNormalizationForwardInference"; +const char* const kCudnnBatchNormForwardTrainingCallTarget = + "__cudnn$batchNormalizationForwardTraining"; +const char* const kCudnnBatchNormBackwardCallTarget = + "__cudnn$batchNormalizationBackward"; + +bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; + } + const auto& target = hlo.custom_call_target(); + return target == kCudnnBatchNormForwardInferenceCallTarget || + target == kCudnnBatchNormForwardTrainingCallTarget || + target == kCudnnBatchNormBackwardCallTarget; +} + bool ImplementedAsLibraryCall(const HloInstruction& hlo) { - return ImplementedAsGemm(hlo) || ImplementedAsDnnConvolution(hlo); + return ImplementedAsGemm(hlo) || ImplementedAsDnnConvolution(hlo) || + IsCustomCallToDnnBatchNorm(hlo); } bool IsReductionToVector(const HloInstruction& reduce) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 06c3205296e4546e39525ec093cc17e2fc375d0d..d24ed9879d084e96862885efaae2f79a256cd71d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -33,6 +33,31 @@ bool ImplementedAsGemm(const HloInstruction& hlo); // Returns true if `hlo` will be implemented as a call to cuDNN convolution. bool ImplementedAsDnnConvolution(const HloInstruction& hlo); +// A call to cuDNN for batch normalization is represented as CustomCall HLO with +// a call target equal to one of these strings. +// +// The operands to and outputs of these calls are the same as those of the +// corresponding HLOs, except: +// +// - epsilon and feature_index are proper operands, at the end of the operands +// list. They must be HLO constants. +// - The cuDNN forward training call returns inv_stddev = +// 1/sqrt(variance + epsilon) in place of plain variance. +// - Similarly, BatchNormGrad accepts inv_stddev in place of the variance +// operand. +extern const char* const kCudnnBatchNormForwardInferenceCallTarget; +extern const char* const kCudnnBatchNormForwardTrainingCallTarget; +extern const char* const kCudnnBatchNormBackwardCallTarget; + +// Returns true if `hlo` will be implemented as a call to a cuDNN batch +// normalization routine. +// +// This returns true if `hlo` is a CustomCall HLO with a call target equal to +// one of the kCudnnBatchNormFoo constants above, but returns *false* for HLOs +// with one of the kBatchNorm opcodes, because these are lowered either to a +// sequence of generic HLOs or to a cuDNN CustomCall. +bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); + // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 6e2bd4e11d3c4ff576edb0df3b724abebfc0e424..095c3df3bfc75cae999edc7fdd800f6e399546dd 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -173,7 +173,7 @@ Status IrEmitter::EmitCallToNestedComputation( return Status::OK(); } -bool IrEmitter::MaybeEmitSpecialAtomicOperation( +bool IrEmitter::MaybeEmitDirectAtomicOperation( const HloComputation& computation, llvm::Value* output_address, llvm::Value* source_address) { CHECK_EQ(2, computation.num_parameters()); @@ -233,102 +233,189 @@ bool IrEmitter::MaybeEmitSpecialAtomicOperation( return false; } -Status IrEmitter::EmitAtomicOperationForNestedComputation( - const HloComputation& computation, llvm::Value* output_address, - llvm::Value* source_address) { - if (computation.num_parameters() != 2) { - // TODO(b/30258929): We only accept binary computations so far. - return Unimplemented( - "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); - } +// Implements atomic binary operations using atomic compare-and-swap +// (atomicCAS) as follows: +// 1. Reads the value from the memory pointed to by output_address and +// records it as old_output. +// 2. Uses old_output as one of the source operand to perform the binary +// operation and stores the result in new_output. +// 3. Calls atomicCAS which implements compare-and-swap as an atomic +// operation. In particular, atomicCAS reads the value from the memory +// pointed to by output_address, and compares the value with old_output. If +// the two values equal, new_output is written to the same memory location +// and true is returned to indicate that the atomic operation succeeds. +// Otherwise, the new value read from the memory is returned. In this case, +// the new value is copied to old_output, and steps 2. and 3. are repeated +// until atomicCAS succeeds. +// +// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If +// the element type of the binary operation is 32 bits or 64 bits, the integer +// type of the same size is used for the atomicCAS operation. On the other hand, +// if the element type is smaller than 32 bits, int32 is used for the atomicCAS +// operation. In this case, atomicCAS reads and writes 32 bit values from +// the memory, which is larger than the memory size required by the original +// atomic binary operation. We mask off the last two bits of the output_address +// and use the result as an address to read the 32 bit values from the memory. +// This can avoid out of bound memory accesses if tensor buffers are 4 byte +// aligned and have a size of 4N, an assumption that the runtime can guarantee. +// +// The pseudo code is shown below. Variables *_address are pointers to a memory +// region with a size equal to the size of the atomicCAS operation, with the +// exception that new_output_address is a pointer to a memory region with a size +// equal to the element size of the binary operation. +// +// element_size = sizeof(element_type); +// atomic_size = max(32, element_size); +// cas_new_output_address = alloca(atomic_size); +// cas_old_output_address = alloca(atomic_size); +// if (atomic_size != element_size) { +// atomic_address = output_address & ((int64)(-4)); +// new_output_address = cas_new_output_address + (output_address & 3); +// } else { +// atomic_address = output_address; +// new_output_address = cas_new_output_address; +// } +// +// *cas_old_output_address = *atomic_address; +// do { +// *cas_new_output_address = *cas_old_output_address; +// *new_output_address = operation(*new_output_address, *source_address); +// (*cas_old_output_address, success) = +// atomicCAS(atomic_address, *cas_old_output_address, +// *cas_new_output_address); +// } while (!success); +// +Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address) { + llvm::PointerType* output_address_type = + llvm::dyn_cast(output_address->getType()); + CHECK_NE(output_address_type, nullptr); + + // element_type is the data type for the binary operation. + llvm::Type* element_type = output_address_type->getPointerElementType(); + int element_size = llvm_ir::GetSizeInBits(element_type); + llvm::Type* element_address_type = element_type->getPointerTo(); + + int atomic_size = (element_size < 32) ? 32 : element_size; + llvm::Type* atomic_type = ir_builder_.getIntNTy(atomic_size); + llvm::Type* atomic_address_type = + atomic_type->getPointerTo(output_address_type->getPointerAddressSpace()); + + // cas_old_output_address and cas_new_output_address point to the scratch + // memory where we store the old and new values for the repeated atomicCAS + // operations. + llvm::Value* cas_old_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); - if (MaybeEmitSpecialAtomicOperation(computation, output_address, - source_address)) { - return Status::OK(); - } - - // Other binary computations can be made atomic as following (labels are basic - // block names used in the IR emitting code later). - // - // atomic_op_loop_preheader: - // ... - // source = *source_address; - // old_output = *output_address; - // do { - // atomic_op_loop_body_entry: - // new_output = computation(old_output, source); - // (old_output, success) = - // atomicCAS(output_address, old_output, new_output); - // } while (!success); - // - // atomic_op_loop_exit: - // ... - // - // TODO(jingyue): Consider encapsulate the logic of emitting control flow to - // something similar to llvm_ir::ForLoop. - // // Emit preparation code to the preheader. llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock(); - llvm::Type* element_ir_type = - output_address->getType()->getPointerElementType(); - // old_output = *output_address; - llvm::Value* old_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "old_output_location"); - ir_builder_.CreateStore(ir_builder_.CreateLoad(output_address, "old_output"), - old_output_location); + + llvm::Value* atomic_memory_address; + // binop_output_address points to the scratch memory that stores the + // result of the binary operation. + llvm::Value* binop_output_address; + if (element_size < 32) { + // Assume the element size is an integer number of bytes. + CHECK_EQ((element_size % sizeof(char)), 0); + llvm::Type* address_int_type = + module_->getDataLayout().getIntPtrType(output_address_type); + atomic_memory_address = + ir_builder_.CreatePtrToInt(output_address, address_int_type); + llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); + llvm::Value* offset = ir_builder_.CreateAnd(atomic_memory_address, mask); + mask = llvm::ConstantInt::get(address_int_type, -4); + atomic_memory_address = ir_builder_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = + ir_builder_.CreateIntToPtr(atomic_memory_address, atomic_address_type); + binop_output_address = ir_builder_.CreateAdd( + ir_builder_.CreatePtrToInt(cas_new_output_address, address_int_type), + offset); + binop_output_address = + ir_builder_.CreateIntToPtr(binop_output_address, element_address_type); + } else { + atomic_memory_address = + ir_builder_.CreateBitCast(output_address, atomic_address_type); + binop_output_address = + ir_builder_.CreateBitCast(cas_new_output_address, element_address_type); + } + + // Use the value from the memory that atomicCAS operates on to initialize + // cas_old_output. + llvm::Value* cas_old_output = + ir_builder_.CreateLoad(atomic_memory_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_old_output_address); + llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( ir_builder_.GetInsertPoint(), "atomic_op_loop_exit"); - - // Emit the body of the loop that repeatedly invokes atomicCAS. llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body", ir_builder_.GetInsertBlock()->getParent()); ir_builder_.SetInsertPoint(loop_body_bb); // Change preheader's successor from loop_exit_bb to loop_body_bb. loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb); - // new_output = computation(old_output, source); - llvm::Value* new_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "new_output_location"); + + // Emit the body of the loop that repeatedly invokes atomicCAS. + // + // Use cas_old_output to initialize cas_new_output. + cas_old_output = + ir_builder_.CreateLoad(cas_old_output_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_new_output_address); + // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - computation, {old_output_location, source_address}, new_output_location)); - - // (old_output, success) = atomicCAS(output_address, old_output, new_output); - int num_bits = llvm_ir::GetSizeInBits(element_ir_type); - llvm::Type* element_int_ir_type = ir_builder_.getIntNTy(num_bits); - // cmpxchg accepts integer only, and bitcast refuses to operate on aggregate - // types, so we bitcast load and store addresses to intN* of the same bit - // width. - llvm::Value* old_output = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(old_output_location, - element_int_ir_type->getPointerTo()), - "old_output"); - llvm::Value* new_output = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(new_output_location, - element_int_ir_type->getPointerTo()), - "new_output"); + computation, {binop_output_address, source_address}, + binop_output_address)); + + llvm::Value* cas_new_output = + ir_builder_.CreateLoad(cas_new_output_address, "cas_new_output"); + + // Emit code to perform the atomicCAS operation + // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, + // cas_new_output); llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( - ir_builder_.CreateBitCast(output_address, - element_int_ir_type->getPointerTo()), - old_output, new_output, llvm::AtomicOrdering::SequentiallyConsistent, + atomic_memory_address, cas_old_output, cas_new_output, + llvm::AtomicOrdering::SequentiallyConsistent, llvm::AtomicOrdering::SequentiallyConsistent); - // cmpxchg returns a pair. The first element is the original value at - // output_address and the second element is whether the swap is successful. + + // Extract the memory value returned from atomicCAS and store it as + // cas_old_output. ir_builder_.CreateStore( - ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), - ir_builder_.CreateBitCast(old_output_location, - element_int_ir_type->getPointerTo())); + ir_builder_.CreateExtractValue(ret_value, 0, "cas_old_output"), + cas_old_output_address); + // Extract the success bit returned from atomicCAS and generate a + // conditional branch on the success bit. ir_builder_.CreateCondBr( ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); - // Restore the insertion point to the exit basic block so that the caller of + // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. SetToFirstInsertPoint(loop_exit_bb, &ir_builder_); return Status::OK(); } +Status IrEmitter::EmitAtomicOperationForNestedComputation( + const HloComputation& computation, llvm::Value* output_address, + llvm::Value* source_address) { + if (computation.num_parameters() != 2) { + // TODO(b/30258929): We only accept binary computations so far. + return Unimplemented( + "We only support atomic functions with exactly two parameters, but " + "computation %s has %lld.", + computation.name().c_str(), computation.num_parameters()); + } + + if (MaybeEmitDirectAtomicOperation(computation, output_address, + source_address)) { + return Status::OK(); + } + + return EmitAtomicOperationUsingCAS(computation, output_address, + source_address); +} + Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); auto on_true = select->operand(1); @@ -518,6 +605,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { "Hit a case for convolution that is not implemented on GPU."); } +Status IrEmitter::HandleFft(HloInstruction* fft) { + if (ShapeUtil::HasZeroElements(fft->shape())) { + // Emit no code for an empty output. + return Status::OK(); + } + return Unimplemented("Hit a case for fft that is not implemented on GPU."); +} + Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { // TODO(b/33011107): Support cross replica sum on GPU. return Unimplemented( @@ -640,6 +735,60 @@ Status IrEmitter::HandleRng(HloInstruction* random) { .EmitLoop(IrName(random)); } +Status IrEmitter::HandleBatchNormInference(HloInstruction*) { + return Unimplemented( + "The GPU backend does not implement BatchNormInference directly. It " + "should be lowered before IR emission to HLO-soup using " + "BatchNormRewriter or to a cudnn CustomCall using " + "CudnnBatchNormRewriter."); +} + +Status IrEmitter::HandleBatchNormTraining(HloInstruction*) { + return Unimplemented( + "The GPU backend does not implement BatchNormTraining directly. It " + "should be lowered before IR emission to HLO-soup using " + "BatchNormRewriter or to a cudnn CustomCall using " + "CudnnBatchNormRewriter."); +} + +Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { + return Unimplemented( + "The GPU backend does not implement BatchNormGrad directly. It should " + "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or " + "to a cudnn CustomCall using CudnnBatchNormRewriter."); +} + +Status IrEmitter::HandleConditional(HloInstruction* conditional) { + auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); + + llvm::Value* conditional_result = GetBasePointer(*conditional); + + llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + GetBasePointer(*pred), + llvm_ir::AsStringRef(IrName(conditional, "load_predicate_value"))); + llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + pred_value, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + llvm_ir::AsStringRef(IrName(conditional, "boolean_predicate"))); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + pred_cond, IrName(conditional, "if_then_else"), &ir_builder_); + + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *conditional->true_computation(), {GetBasePointer(*true_arg)}, + conditional_result)); + + SetToFirstInsertPoint(if_data.false_block, &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *conditional->false_computation(), {GetBasePointer(*false_arg)}, + conditional_result)); + + SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + return Status::OK(); +} + llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) { @@ -648,8 +797,8 @@ llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( // reduction dimension. std::vector dimensions; const Shape& shape = operand_array.GetShape(); - for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape.layout().minor_to_major(i); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape.layout(), i); if (dimension != reduction_dimension) { dimensions.push_back(dimension); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 9c01f5b7c72f429822300af28bfd5261150d33d1..39bafaa34656a35f24444dc7f3665c1250833921 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -79,6 +79,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; + Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; @@ -95,6 +96,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleRng(HloInstruction* random) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandleBatchNormInference(HloInstruction* batch_norm) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } @@ -185,9 +190,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { // be simply implemented using an LLVM atomic instruction. If "computation" is // one of this kind, emits code to do that and returns true; otherwise, // returns false. - bool MaybeEmitSpecialAtomicOperation(const HloComputation& computation, - llvm::Value* output_address, - llvm::Value* source_address); + bool MaybeEmitDirectAtomicOperation(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); + + // A helper method for EmitAtomicOperationForNestedComputation. It implements + // binary atomic operations using atomicCAS with special handling to support + // small data types. + Status EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); StatusOr ComputeNestedElement( const HloComputation& computation, @@ -227,8 +239,11 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; Status HandleConvolution(HloInstruction* convolution) override; + Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleDot(HloInstruction* dot) override; + Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; @@ -292,6 +307,12 @@ class IrEmitterUnnested : public IrEmitter { const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer); + // Emits code that reduces a tensor of arbitrary rank to a scalar. + Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + // Figures out whether `reduce` is a row or column reduction, and which // dimensions to reduce, and calls either `EmitRowReduction` or // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the @@ -319,6 +340,9 @@ class IrEmitterUnnested : public IrEmitter { // Returns a ConvolutionThunk that calls DNN to implement `inst`. std::unique_ptr BuildConvolutionThunk(const HloInstruction* inst); + // Returns a FftThunk that calls cuFFT to implement `inst`. + std::unique_ptr BuildFftThunk(const HloInstruction* inst); + // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr BuildGemmThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1b863c9e3c51d6e757751154abd653cd1fdcb8a7..be35351e8727ce15998460e41f21a53ebe427c3b 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -30,8 +30,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -123,10 +126,12 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( llvm::IntegerType::get(llvm_context, /*NumBits=*/32), launch_dims.threads_per_block()); + // Our launch bounds are exact, so we can specify them as reqntidx rather than + // maxntidx. nvvm_annotations_node->addOperand(llvm::MDNode::get( llvm_context, {llvm::ConstantAsMetadata::get(ir_kernel), - llvm::MDString::get(llvm_context, "maxntidx"), + llvm::MDString::get(llvm_context, "reqntidx"), llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } } // namespace @@ -181,15 +186,15 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( llvm_ir::SanitizeFunctionName(inst.name())); - // Create the kernel and adds it to the module. + // Create the kernel and add it to the module. llvm::Module* module = ir_emitter_context_->llvm_module(); llvm::LLVMContext& context = module->getContext(); int num_escaped_hlos = escaped_hlos.size(); llvm::FunctionType* kernel_type = llvm::FunctionType::get( - llvm::Type::getVoidTy(context), // The type of function result. + /*Result=*/llvm::Type::getVoidTy(context), std::vector(num_escaped_hlos + 1, ir_builder_.getInt8PtrTy()), - false); // Not a variadic argument function. + /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, kernel_name.c_str(), module); @@ -214,7 +219,14 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, temp_allocation_total_size); } - kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias); + kernel->addParamAttr(temp_buffer_arg_no, llvm::Attribute::NoAlias); + + // All arguments to a kernel must be aligned to kCudaMallocAlignBytes. + for (int64 i = 0; i < kernel->arg_size(); ++i) { + kernel->addParamAttr( + i, llvm::Attribute::get(context, llvm::Attribute::Alignment, + kCudaMallocAlignBytes)); + } // TODO(b/65380986): Investigate if adding fast math flags for generated // kernels makes sense. @@ -246,6 +258,11 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); @@ -254,6 +271,11 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { return IrEmitter::HandleDot(dot); } +Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { + thunk_sequence_->push_back(BuildKernelThunk(conditional)); + return IrEmitter::HandleConditional(conditional); +} + Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { if (ImplementedAsDnnConvolution(*convolution)) { thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution)); @@ -263,6 +285,111 @@ Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { return IrEmitter::HandleConvolution(convolution); } +Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { + // A CustomCall on the GPU backend can either be a custom-call to a + // user-supplied kernel, or a call into a library like cudnn. + + // Lower custom-calls to cudnn batchnorm ops to specialized thunks. It's part + // of the contract of these cudnn batchnorm calls that the epsilon and + // feature_index operands be constants. + if (custom_call->custom_call_target() == + kCudnnBatchNormForwardInferenceCallTarget) { + const HloInstruction* epsilon = custom_call->operand(5); + CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get({}); + + const HloInstruction* feature_index = custom_call->operand(6); + CHECK(feature_index->IsConstant()); + int64 feature_index_value = feature_index->literal().Get({}); + + thunk_sequence_->emplace_back( + MakeUnique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*offset=*/GetAllocationSlice(*custom_call->operand(2)), + /*mean=*/GetAllocationSlice(*custom_call->operand(3)), + /*variance=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); + return Status::OK(); + } + + if (custom_call->custom_call_target() == + kCudnnBatchNormForwardTrainingCallTarget) { + const HloInstruction* epsilon = custom_call->operand(3); + CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get({}); + + const HloInstruction* feature_index = custom_call->operand(4); + CHECK(feature_index->IsConstant()); + int64 feature_index_value = feature_index->literal().Get({}); + + // BatchNormTraining returns a tuple of three elements: data, calculated + // mean, and calculated 1/sqrt(variance + epsilon). + const auto& assn = ir_emitter_context_->buffer_assignment(); + auto output_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); + auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); + auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); + thunk_sequence_->emplace_back( + MakeUnique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*offset=*/GetAllocationSlice(*custom_call->operand(2)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_data=*/output_data, + /*output_mean=*/output_mean, + /*output_inv_stddev=*/output_inv_stddev, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); + return Status::OK(); + } + + if (custom_call->custom_call_target() == kCudnnBatchNormBackwardCallTarget) { + const HloInstruction* epsilon = custom_call->operand(5); + CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get({}); + + const HloInstruction* feature_index = custom_call->operand(6); + CHECK(feature_index->IsConstant()); + int64 feature_index_value = feature_index->literal().Get({}); + + // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale, + // grad_offset. + const auto& assn = ir_emitter_context_->buffer_assignment(); + auto output_grad_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); + auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); + auto output_grad_offset = + assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); + thunk_sequence_->emplace_back(MakeUnique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); + return Status::OK(); + } + + return IrEmitter::HandleCustomCall(custom_call); +} + +Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { + TF_RET_CHECK( + LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout())); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); + thunk_sequence_->emplace_back(BuildFftThunk(fft)); + return Status::OK(); +} + Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); // HandleFusion specializes reduction from a multi-dimensional array to a 1D @@ -407,8 +534,8 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, (segs.size() == i ? shape.dimensions().size() : segs[i]), 1, std::multiplies())); } - return ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), - dimensions); + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + dimensions); } // Returns whether the given shapes and permutation are a 0-2-1 transpose, and @@ -421,20 +548,22 @@ std::tuple IsTranspose021(const Shape& a, const Shape& b) { CHECK(ShapeUtil::Compatible(a, b)); std::vector perm(a.dimensions().size()); { - std::vector layout_a(a.layout().minor_to_major().rbegin(), - a.layout().minor_to_major().rend()); - std::vector layout_b(b.layout().minor_to_major().rbegin(), - b.layout().minor_to_major().rend()); + auto layout_a_orig = LayoutUtil::MinorToMajor(a); + std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); + auto layout_b_orig = LayoutUtil::MinorToMajor(b); + std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); for (size_t i = 0; i < perm.size(); ++i) { perm[i] = PositionInContainer(layout_b, layout_a[i]); } } auto segs = ConsecutiveSegments(perm); - Shape norm_a = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a); - Shape norm_b = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b); + Shape norm_a = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); + Shape norm_b = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b); if (3 == segs.size() && 0 == perm[0]) { Shape reduced_a = MergeDimensions(segs, norm_a); - Shape reduced_b = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout( b.element_type(), Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions()))); return std::make_tuple(true, reduced_a, reduced_b); @@ -448,10 +577,11 @@ std::tuple IsTranspose021(const Shape& a, const Shape& b) { bool AreShapesForTranspose021(const Shape& a, const Shape& b) { return 3 == b.dimensions().size() && ShapeUtil::Compatible( - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a), + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a), ShapeUtil::PermuteDimensions( {0, 2, 1}, - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b))); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + b))); } // Emits a tiled 0-2-1 transpose, assuming both input and output lain out from @@ -483,9 +613,11 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape())); Shape input_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input.GetShape()); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input.GetShape()); Shape output_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(output.GetShape()); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + output.GetShape()); input = input.CastToShape(input_shape, builder); output = output.CastToShape(output_shape, builder); @@ -603,7 +735,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder))), builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"), - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + ShapeUtil::MakeShapeWithDescendingLayout( PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)), builder); const llvm_ir::IrArray::Index input_tile_origin = ({ @@ -706,6 +838,194 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { return IrEmitter::HandleCopy(copy); } +Status IrEmitterUnnested::EmitReductionToScalar( + HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + // Number of elements processed by a single thread. + constexpr int64 kTileSize = 16; + int64 num_elems = ShapeUtil::ElementsIn(input_shape); + + // Round up the number of tiles to a multiple of the warp size. This is + // necessary for correctness. We launch one thread per tile, and if the + // number of threads isn't a multiple of the number of the warp size, our + // shuffles will read from inactive threads, producing undefined values. + int64 num_tiles = + RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); + + // Check whether every thread will process a full tile's worth of elements + // without reading outside the bounds of the input. If this is true, we can + // skip some bounds checks in the final algorithm. + bool all_threads_in_bounds = num_tiles * kTileSize == num_elems; + + // __global__ void full_reduce_kernel() { + // x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x; + // x = x_in_tiles * kTileSize; + // + // partial_result = init_value; + // if (all_threads_in_bounds || x + kTileSize <= num_elems) { + // for (i = 0; i < kTileSize; ++i) { + // partial_result = Reducer(partial_result, input[x + i]); + // } + // } else { + // for (i = 0; i < kTileSize; ++i) { + // if (x + i < num_elems) { + // partial_result = Reducer(partial_result, input[x + i]); + // } + // } + // } + // for (i = warpSize / 2; i > 0; i /= 2) { + // partial_result = Reducer(partial_result, + // __shfl_down(partial_result, i)); + // } + // if (lane_id == 0) { + // AtomicReducer(&output[y], partial_result); + // } + // } + // + // // Choose num_blocks and threads_per_block such that: + // // + // // num_blocks * threads_per_block = + // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), + // // + // // and threads_per_block is a multiple of warpSize. + // reduce_kernel<<>>(); + // + auto loop_body_emitter = + [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + llvm::Type* element_ir_type = + llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); + { + TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, + init_value_gen(llvm_ir::IrArray::Index({}))); + ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + } + + llvm::Value* x_in_tiles = tile_index[0]; + + // Emit an inner for-loop that reduces the elements in the tile. + auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { + std::unique_ptr tile_element_loop = + llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", + ir_builder_.getInt64(0), + ir_builder_.getInt64(kTileSize), + ir_builder_.getInt64(1), &ir_builder_); + + // Emit the body of the partial reduction loop. + llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), + &ir_builder_); + llvm::Value* x = ir_builder_.CreateNSWAdd( + ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)), + tile_element_loop->GetIndVarValue()); + // Unless we know the tile is entirely in bounds, we have to emit a + // x-in-bounds check before reading from the input. + if (!tile_in_bounds) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(num_elems)), + "x_in_bounds", &ir_builder_); + + // Emit code that reads the input element and accumulates it to + // the partial reduction result. + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + } + llvm_ir::IrArray::Index input_index( + /*linear=*/x, input_shape, &ir_builder_); + llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); + TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + return (EmitCallToNestedComputation( + *reducer, {partial_reduction_result_address, input_address}, + partial_reduction_result_address)); + }; + + // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's + // immediately beyond the tile. + llvm::Value* x_end = ir_builder_.CreateNSWAdd( + ir_builder_.getInt64(kTileSize), + ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize))); + // The tile is entirely in bound if all_threads_in_bounds or + // x_end <= num_elems. + llvm::Value* tile_in_bounds = ir_builder_.CreateOr( + ir_builder_.CreateICmpULE(x_end, ir_builder_.getInt64(num_elems)), + ir_builder_.getInt1(all_threads_in_bounds)); + llvm_ir::LlvmIfData if_tile_in_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, + &ir_builder_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, + &ir_builder_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); + + // After the if-then-else statement on tile_in_bounds, emit calls to + // shfl_down that accumulate the partial reduction results of all threads + // from the warp. + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, + &ir_builder_); + int bit_width = llvm_ir::GetSizeInBits(element_ir_type); + // bitcast cannot be applied to aggregate types (even packed ones), so we + // instead bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() + ? ir_builder_.getIntNTy(bit_width) + : element_ir_type; + for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; + shuffle_distance /= 2) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_address, + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( + element_ir_type, nullptr, "result_from_other_lane"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducer, {partial_reduction_result_address, result_from_other_lane}, + partial_reduction_result_address)); + } + + const HloInstruction* output = + reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; + + // Emit an atomic operation that accumulates the partial reduction result of + // lane 0 (which holds the partially accumulated result for its warp) to the + // output element. + llvm::Value* lane_id = ir_builder_.CreateURem( + x_in_tiles, ir_builder_.getInt64(kWarpSize), "lane_id"); + llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), + "lane_id_is_zero", &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, + &ir_builder_); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0), + output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); + return EmitAtomicOperationForNestedComputation( + *reducer, output_address, partial_reduction_result_address); + }; + + // Emit a parallel loop that iterates through all input tiles, one per thread. + Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {num_tiles}, {0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); + UpdateLaunchDimensions( + launch_dimensions, + static_cast(LastThunk())->thunks().back().get(), + ir_emitter_context_->llvm_module()); + return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, + launch_dimensions, &ir_builder_) + .EmitLoop(IrName(reduce)); +} + Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen, @@ -799,14 +1119,15 @@ Status IrEmitterUnnested::EmitColumnReduction( // input_shape to normalized_input_shape and a reshape from // normalized_input_shape to input_matrix_shape. const Shape normalized_input_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); const std::vector transpose_dimension_mapping( - input_shape.layout().minor_to_major().rbegin(), - input_shape.layout().minor_to_major().rend()); + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); const Shape input_matrix_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), {height, width}); + ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), + {height, width}); const llvm_ir::IrArray::Index input_matrix_index( {y, x}, input_matrix_shape, &ir_builder_); const llvm_ir::IrArray::Index input_index = @@ -901,7 +1222,7 @@ Status IrEmitterUnnested::EmitRowReduction( // // Three optimizations are performed. // - // 1. To coalesc global memory accesses, dilate the tile with a factor of 32 + // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead // of making each tile consecutive, we let make tile 0 column // [0,32,64,...,224], tile 1 column [1,33,65,...,225], and so on. This ensures @@ -1042,13 +1363,14 @@ Status IrEmitterUnnested::EmitRowReduction( // from input_shape to normalized_input_shape and a reshape from // normalized_input_shape to input_3d_tensor_shape. const Shape normalized_input_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); const std::vector transpose_dimension_mapping( - input_shape.layout().minor_to_major().rbegin(), - input_shape.layout().minor_to_major().rend()); + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), {depth, height, width}); + ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), + {depth, height, width}); const llvm_ir::IrArray::Index input_3d_tensor_index( {z, y, x}, input_3d_tensor_shape, &ir_builder_); const llvm_ir::IrArray::Index input_index = @@ -1177,9 +1499,9 @@ Status IrEmitterUnnested::EmitReductionToVector( // whether another dimension is major or minor of them. std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(), [&input_shape](int64 dim_a, int64 dim_b) { - return PositionInContainer(input_shape.layout().minor_to_major(), + return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) < - PositionInContainer(input_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b); }); // Now, if output rank is at least 1, `input_dims_to_keep.front()` is @@ -1189,14 +1511,11 @@ Status IrEmitterUnnested::EmitReductionToVector( // the dimensions to keep are contiguous, by prerequisite of // `EmitReductionToVector`, we only need to check whether the minormost // dimension of the input is to keep. - // - // If the output is scalar, we could emit either a row or a column reduction. - // Some tests have shown scalar reduction is no more efficient as row - // reduction, and is simpler to emit as column reduction, so we emit a column - // reduction in this case. - if (input_dims_to_keep.empty() || - input_dims_to_keep.front() == - LayoutUtil::Minor(input_shape.layout(), 0)) { + if (input_dims_to_keep.empty()) { + return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen, + reducer); + } else if (input_dims_to_keep.front() == + LayoutUtil::Minor(input_shape.layout(), 0)) { // Column reduction. Treat the result of "input" as a matrix whose width // is the most minor dimension and height the product of other dimensions, // and treat "reduce" as a column reduction of the input matrix. @@ -1224,14 +1543,14 @@ Status IrEmitterUnnested::EmitReductionToVector( int64 width = 1; for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); ++input_dim) { - if (PositionInContainer(input_shape.layout().minor_to_major(), + if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dim) > - PositionInContainer(input_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dims_to_keep.back())) { depth *= input_shape.dimensions(input_dim); - } else if (PositionInContainer(input_shape.layout().minor_to_major(), + } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dim) < - PositionInContainer(input_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dims_to_keep.front())) { width *= input_shape.dimensions(input_dim); } @@ -1611,7 +1930,7 @@ std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); return MakeUnique( - /*source_address=*/operand->literal().InternalData(), + /*source_address=*/operand->literal().untyped_data(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ llvm_ir::ByteSizeOf(operand->shape(), @@ -1738,6 +2057,16 @@ std::unique_ptr IrEmitterUnnested::BuildConvolutionThunk( } } +std::unique_ptr IrEmitterUnnested::BuildFftThunk( + const HloInstruction* inst) { + const HloInstruction* operand = inst->operand(0); + return MakeUnique(inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); +} + Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, KernelThunk* thunk) { bool fused = HloOpcode::kFusion == hlo->opcode(); diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc deleted file mode 100644 index ac206b89d329d7e4ac91ee51162c9694f6899d78..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_layout.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { -namespace gpu { -namespace { - -using LayoutAssignmentTest = HloTestBase; - -TEST_F(LayoutAssignmentTest, Elementwise) { - Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); - Shape ashape_in_row_major(ashape); - Shape ashape_in_col_major(ashape); - *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - - // Enumerate all possible combinations of layouts. - for (const Shape& lhs_shape_with_layout : - {ashape_in_row_major, ashape_in_col_major}) { - for (const Shape& rhs_shape_with_layout : - {ashape_in_row_major, ashape_in_col_major}) { - for (const Shape& result_shape_with_layout : - {ashape_in_row_major, ashape_in_col_major}) { - // GpuLayoutAssignment should assign the same layout to "add" and its - // two operands. - auto builder = HloComputation::Builder(TestName()); - auto x = builder.AddInstruction( - HloInstruction::CreateParameter(0, ashape, "x")); - auto y = builder.AddInstruction( - HloInstruction::CreateParameter(1, ashape, "y")); - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(add)); - - ComputationLayout computation_layout( - computation->ComputeProgramShape()); - *computation_layout.mutable_parameter_layout(0) = - ShapeLayout(lhs_shape_with_layout); - *computation_layout.mutable_parameter_layout(1) = - ShapeLayout(rhs_shape_with_layout); - *computation_layout.mutable_result_layout() = - ShapeLayout(result_shape_with_layout); - - GpuLayoutAssignment layout_assignment(&computation_layout); - EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); - - for (const HloInstruction* operand : add->operands()) { - EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(), - operand->shape().layout())); - } - } - } - } -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 059943d48cd34b0ac487b91c3f3079ee3f761229..cfabae791d26d0eb49826085ad7ad166a19109a1 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -440,7 +440,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // One-time module initializer. // Must be called only once -- DO NOT CALL DIRECTLY. -void GPUBackendInit() { +void GPUBackendInit(const HloModuleConfig& hlo_module_config) { // Feed all customized flags here, so we can override them with llvm_cl_opts // without redeploy the compiler for development purpose. @@ -466,6 +466,8 @@ void GPUBackendInit() { // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); + llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config); + // Initialize the NVPTX target; it's the only target we link with, so call its // specific initialization functions instead of the catch-all InitializeAll*. LLVMInitializeNVPTXTarget(); @@ -485,7 +487,7 @@ StatusOr CompileToPtx(llvm::Module* module, const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path) { static std::once_flag backend_init_flag; - std::call_once(backend_init_flag, GPUBackendInit); + std::call_once(backend_init_flag, GPUBackendInit, hlo_module_config); string ptx; { diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 11290eda4ffcd579c03acd531b493bb7b1d34ed4..c29fee0879c02021fdc23ac0e02ab398cf40f99e 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -202,8 +202,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose the lesser of padding_low and padding_high as the new padding. - HloInstruction* transpose = backward_conv->fused_expression_root(); - HloInstruction* forward_conv = transpose->mutable_operand(0); + HloInstruction* forward_conv = backward_conv->fused_expression_root(); HloInstruction* input = backward_conv->mutable_operand(0); Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); @@ -269,19 +268,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), padded_input, output, new_forward_conv_window, forward_conv_dnums)); - HloInstruction* new_transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - ShapeInference::InferTransposeShape(new_forward_conv->shape(), - transpose->dimensions()) - .ConsumeValueOrDie(), - new_forward_conv, transpose->dimensions())); - - // Fuse the new forward convolution and the new transpose to the new backward - // convolution. + // Fuse the new forward convolution to the new backward convolution. HloInstruction* new_backward_conv = computation->CreateFusionInstructionForBackwardConvolution( - {new_transpose, new_forward_conv}, - HloInstruction::FusionKind::kConvBackwardFilter, + {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, new_backward_conv_window, backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 457e6094d90413440658452937bff2ccfe6cbe5c..388dcc008b07a76ff9ed07df04181e49a8734f51 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -88,6 +88,23 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); + // Add an @llvm.assume(linear_index < threads_per_block * num_blocks). + // + // This might seem obvious from the computation above, but LLVM does not + // currently determine the range of linear_index precisely. InstCombine uses + // known-bits, which, when applied to the task of determining a value's range, + // is imprecise for everything other than powers of 2. And + // CorrelatedValuePropagation is, as a cost-saving measure, disabled for + // conditions in the same basic block as their operands. + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::assume, + {ir_builder_->CreateICmpULT( + linear_index, + ir_builder_->getInt64(launch_dimensions_.threads_per_block() * + launch_dimensions_.block_count()), + "linear_index_in_range")}, + {}, ir_builder_); + auto if_in_bounds = llvm_ir::EmitIfThenElse( ir_builder_->CreateICmpULT( linear_index, ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d0d2deee24848184278e3e51dcaa3bb673b5fadc..6cf280df05496716a0780d61ded92efd9982734c 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -44,37 +44,41 @@ std::ostream& operator<<(std::ostream& out, // Calculates the launch dimensions used to invoke `hlo`. LaunchDimensions CalculateLaunchDimensions( - const Shape& shape, const se::DeviceDescription& device_desc, - PartitionStrategy partition_strategy) { - int64 warp_size = device_desc.threads_per_warp(); - + const Shape& shape, const se::DeviceDescription& device_desc) { int64 num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); } - // Calculate the number of threads per block. - // Initialize threads_per_block as the threads-per-block limit. - int64 threads_per_block = device_desc.threads_per_block_limit(); - VLOG(2) << "Initial # of threads per block = " << threads_per_block; - - if (partition_strategy == PartitionStrategy::kLatency) { - // Limit the thread count to allow maximum number of registers per thread. - // TODO(b/28560520): We don't have to assume the emitted kernel will use up - // all the registers. We could use ptxas to examine the actual number of - // register used, and set the thread count accordingly. - int64 threads_per_block_limit_due_to_registers = - device_desc.registers_per_core_limit() / - device_desc.registers_per_thread_limit(); - CHECK_NE(0, threads_per_block_limit_due_to_registers); - if (threads_per_block_limit_due_to_registers < threads_per_block) { - threads_per_block = - // Make `threads_per_block` a multiple of warp size to use GPU - // efficiently. - warp_size * - std::max(1LL, threads_per_block_limit_due_to_registers / warp_size); - VLOG(2) << "Update # of threads per block due to register pressure = " - << threads_per_block; + // Since we don't do any inter-warp communication, we're free to choose any + // block size we want, subject to hardware constraints. We choose the + // smallest block size that allows the GPU to reach full occupancy (assuming + // the kernel uses sufficiently few registers). This gives us max performance + // when the kernel uses few registers, and lets us scale down gracefully as + // the kernel uses more registers. + // + // Specifically, we choose the number of threads per block such that + // + // * = + + auto threads_per_core = device_desc.threads_per_core_limit(); + auto blocks_per_core = device_desc.blocks_per_core_limit(); + int64 threads_per_block; + if (threads_per_core != 0 && blocks_per_core != 0) { + threads_per_block = device_desc.threads_per_core_limit() / + device_desc.blocks_per_core_limit(); + } else { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. " + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; } } @@ -84,8 +88,6 @@ LaunchDimensions CalculateLaunchDimensions( << threads_per_block << ") because the latter is smaller."; } - // Calculate the block count. We copy the strategy used by Eigen: - // eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h int64 block_count = CeilOfRatio(num_elements, threads_per_block); VLOG(2) << tensorflow::strings::Printf( "Initialized the block count to ceil(# of elements / threads per " diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 8f7fce884acc93fd39510ad0826b819a6d9731a7..0bf463a6ef95d5a32784838c08ad239752fd1acf 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -30,14 +30,6 @@ limitations under the License. namespace xla { namespace gpu { -enum class PartitionStrategy { - // Optimized for latency by allowing maximum number of registers per thread. - kLatency, - // Optimized for throughput. This may limit registers per thread and cause - // longer latency. - kThroughput -}; - // Encapsulates the launch dimensions of a kernel, e.g., the block count and the // number of threads per block. class LaunchDimensions { @@ -66,8 +58,7 @@ std::ostream& operator<<(std::ostream& out, LaunchDimensions CalculateLaunchDimensions( const Shape& shape, - const perftools::gputools::DeviceDescription& device_desc, - PartitionStrategy partition_strategy = PartitionStrategy::kLatency); + const perftools::gputools::DeviceDescription& device_desc); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 0ff27888ad72f8190400c22a9086d1965448662c..625c3f8bea418b7942145a05ba42b9ea9b14543b 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -43,6 +43,10 @@ class Thunk { enum class Kind { kConvolution, kCopy, + kCudnnBatchNormBackward, + kCudnnBatchNormForwardInference, + kCudnnBatchNormForwardTraining, + kFft, kGemm, kInfeed, kKernel, @@ -70,6 +74,29 @@ class Thunk { return tensorflow::Status::OK(); } + // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) + // before calling ExecuteOnStream(stream). If it returns true, it's the + // user's responsibility to wait for all activity on the GPU to finish before + // calling ExecuteOnStream. + // + // This value is not required to be constant for a given Thunk. For example, + // a Thunk that performs autotuning may return true for its first run and + // false thereafter. + virtual bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* /*stream*/) { + return false; + } + + // Indicates whether thunks scheduled after this one should wait for this one + // to complete before running. For example, a convolution thunk creates a + // scratch allocator, then kicks off a convolution in cudnn via the stream + // executor. When the stream executor call returns, the scratch allocator goes + // out of scope, and the scratch memory is deallocated. In this case, the + // convolution thunk needs to return true so that future thunks wait for the + // convolution thunk to avoid reusing the deallocated memory until the + // convolution thunk is done with it. + virtual bool ShouldBlockFutureThunks() { return false; } + // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 0d2412096abf7838b7b0e7617811c789f507a4a1..c21559af6d2e5dfb5aaf62afcdcaed514e0914c9 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -34,16 +34,14 @@ WhileThunk::WhileThunk( body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status WhileThunk::Initialize(const GpuExecutable& executable) { +Status WhileThunk::Initialize(const GpuExecutable& executable) { TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status WhileThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - +Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { perftools::gputools::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); @@ -55,9 +53,11 @@ tensorflow::Status WhileThunk::ExecuteOnStream( // Copy the result of condition computation and break the loop if 'false'. bool condition_result; stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); - if (!stream->BlockHostUntilDone()) { + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { return InternalError( - "Failed to complete all kernels launched on stream %p", stream); + "Failed to complete all kernels launched on stream %p: %s", stream, + block_status.error_message().c_str()); } if (!condition_result) { @@ -68,7 +68,7 @@ tensorflow::Status WhileThunk::ExecuteOnStream( TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 95ed5497cea4fa3ba5dcdc6762cbd53cec88339a..4c9f45de9e42494df58706d0a4a3eb0c4220b8b8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -45,10 +45,9 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status Initialize(const GpuExecutable& executable) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: const BufferAllocation::Slice condition_result_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index ccdd1717593e4fa7c1d1deb3f0f9ebfab1bf7209..e6caec8625f0d622dbb92bcc20802d254fe23f94 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -44,7 +44,7 @@ namespace { // // Parameter // | -// Const GetTupleElemet +// Const GetTupleElement // \ / // Add (root) // @@ -62,7 +62,7 @@ namespace { // &tagged_instructions)); // // Instructions that are "tagged" with a context-specific string will -// be returned in 'tagged_instructions' for further procesing (i.e. parsing +// be returned in 'tagged_instructions' for further processing (i.e. parsing // constants or recording the tuple_index). // class ExprTree { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index f16daa0b5481474e754c880ead1945297ca50168..2f290f61bd527e9827472a78256f015e066e44be 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -117,9 +117,7 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { - HloVerifier verifier([](const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); - }); + HloVerifier verifier; TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 049e8d80d80c835bca4a4d38592564ba82a3ecf9..05017008e2ddbe0b9e78d06275fdec5d08d94bfa 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -108,8 +108,11 @@ std::unique_ptr MakeBigGraph() { HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(vshape, HloOpcode::kDot, clamp, param_v0)); + HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index a03ad2f37cf5ede35275ea019ab3d5998fb85d0a..88a8698d16132372fc8f4e87eba3b99125aab876 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -264,7 +264,7 @@ class LazyBestFitHeap : public HeapAlgorithm { enum { kLazyAllocOffset = -1 }; struct OrderChunkByIncreasingSize { - bool operator()(const Chunk& a, const Chunk& b) { + bool operator()(const Chunk& a, const Chunk& b) const { if (a.size != b.size) return a.size < b.size; return a.offset < b.offset; } diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 17b926c8748e45b55f380e7595711b9e7a748f64..387b649a731ebcbfd8307807469f39f22d192b06 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -259,8 +259,11 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. @@ -292,8 +295,11 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -327,10 +333,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); // The buffer for dot1 is the output. No buffers can be shared. The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -365,10 +374,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index e984bdb5f75f714fb7b4453a97178158d9b8a8b8..0e9a852788e978f79fa6f6c802f855a4c476583f 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -36,6 +36,9 @@ option cc_enable_arenas = true; // Serialization of HloInstruction. message HloInstructionProto { + reserved 10; + reserved "parameter_name"; + string name = 1; string opcode = 2; xla.Shape shape = 3; @@ -50,9 +53,8 @@ message HloInstructionProto { // Literal, only present for kConstant. xla.LiteralProto literal = 8; - // Parameter info, only present for kParameter. + // Parameter number is only present for kParameter. int64 parameter_number = 9; - string parameter_name = 10; // Fusion state, only present for kFusion. string fusion_kind = 11; @@ -118,6 +120,15 @@ message HloInstructionProto { // Shape of outfeed request. xla.Shape outfeed_shape = 29; + + // Describes the dimension numbers used for a dot operation + xla.DotDimensionNumbers dot_dimension_numbers = 30; + + // FFT type (FFT, IFFT, etc). + xla.FftType fft_type = 31; + + // FFT length. + repeated int64 fft_length = 32; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c215cc48d60b93a88d64b7c4aecb2aa3bb460443..a63affa06caf75f1ccab084bd114e39ba7c91a38 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -131,9 +131,9 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = param_instruction->parameter_name(); + string param_name = param_instruction->name(); // Fusion parameters are named foo.param_1, bar.param_2, etc. We are - // renumbering the parameters so replace the final number in the name with + // renumbering the parameters, so replace the final number in the name with // the updated value. const string param_underscore = ".param_"; size_t index = param_name.rfind(param_underscore); @@ -176,10 +176,6 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { return false; } - if (instruction->HasSideEffect()) { - return false; - } - return true; } @@ -207,7 +203,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( worklist.pop(); if (removed.count(item) != 0 || item->user_count() != 0 || - item == root_instruction() || !IsRemovable(item)) { + item == root_instruction() || !IsRemovable(item) || + item->HasSideEffect()) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -367,26 +364,27 @@ std::list HloComputation::MakeEmbeddedComputationsList() return post_order; } -string HloComputation::ToString(int nested_level, - bool include_large_constants) const { +string HloComputation::ToString(const HloPrintOptions& options) const { std::ostringstream s; - for (int i = 0; i < nested_level; i++) { + for (int i = 0; i < options.indent_amount(); i++) { s << " "; } - s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) - << " {\n"; + if (options.print_percent()) { + s << "%"; + } + s << name(); + if (options.print_program_shape()) { + s << " " << ShapeUtil::HumanString(ComputeProgramShape()); + } + s << " {\n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { - for (int i = 0; i < nested_level; i++) { + for (int i = 0; i < options.indent_amount(); i++) { s << " "; } s << " " << (instruction == root_instruction_ ? "ROOT " : "") - << instruction->ToString( - /*compact_operands=*/false, - /*include_metadata=*/true, - /*include_large_constants=*/include_large_constants) - << "\n"; + << instruction->ToString(options) << "\n"; } - for (int i = 0; i < nested_level; i++) { + for (int i = 0; i < options.indent_amount(); i++) { s << " "; } s << "}"; @@ -543,7 +541,7 @@ ProgramShape HloComputation::ComputeProgramShape() const { for (auto* param_instruction : param_instructions_) { *program_shape.add_parameters() = param_instruction->shape(); - *program_shape.add_parameter_names() = param_instruction->parameter_name(); + *program_shape.add_parameter_names() = param_instruction->name(); } *program_shape.mutable_result() = root_instruction_->shape(); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 353b30bc69d98556311635d6097e3d6ad5fb2aaa..6436815f910405477ec21a33dec75ef71df08602 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -138,8 +138,11 @@ class HloComputation { void UniquifyName(NameUniquer* name_uniquer); // Return a string representation of the computation. - string ToString(int nested_level = 0, - bool include_large_constants = false) const; + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) + string ToString() const { return ToString(HloPrintOptions()); } + string ToString(const HloPrintOptions& options) const; // Returns a serialized representation of this computation. HloComputationProto ToProto() const; @@ -313,11 +316,17 @@ class HloComputation { replacements, HloModule* module = nullptr, const string& suffix = "clone"); - // Returns true if the given instruction can be removed from the - // computation. Instructions such as parameters and send/receive instructions - // cannot be removed without violating invariants of the HLO computation or - // module with the exception of fusion computation. A parameter instruction - // is removable for a fusion computation. + // Returns true if the given instruction can be removed from the computation. + // Parameter instructions cannot be removed without violating invariants of + // the HLO computation with the exception of fusion computation. A parameter + // instruction is removable for a fusion computation. + // + // Note that IsRemovable() is a necessariy condition to remove an instruction + // rather than a sufficient condition. For example, instructions with + // side-effect (e.g., Send, Infeed) may be removed from a computation, but the + // transformation must guarantee the invariants relevant to the instructions + // still hold (e.g., Send and Recv must be removed together to make each + // channel complete). bool IsRemovable(const HloInstruction* instruction); // Returns true if this computation has a side effect. A computation has a diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 6fcc01dd64d1ac041e99eedb8b1de476409b257d..cd54eb74d18d0be714b5b56fc8ae0dfa55ff31a0 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -201,10 +201,11 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1); - + int64 reduction_width = + lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); // First divide by reduction width before multiplying by rhs elements to avoid // overflow. int64 fma_count; @@ -391,13 +392,35 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { return Status::OK(); } +Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { + auto real_shape = + ShapeUtil::IsTuple(fft->operand(0)->shape()) + ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0) + : fft->operand(0)->shape(); + constexpr int kFmaPerComplexMul = 4; + int64 log_factors = 1; + for (int64 dim : fft->fft_length()) { + log_factors *= tensorflow::Log2Floor(dim); + } + current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors * + ShapeUtil::ElementsIn(real_shape); + return Status::OK(); +} + Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. // // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. - current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape()); + double flops = 0.0; + ShapeUtil::ForEachSubshape( + crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index fade19522cf0c30eab037aa355de1f9203f80014..e5783539e5436f09fa58bf7889118380ee90fea0 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleCopy(const HloInstruction* copy) override; Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; + Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index d35ba19a730555433099072c51ca5cf3774d4b99..7feda2b3b040de1f0a14303ce1adcd21c6624c8b 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { @@ -91,6 +92,10 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { StatusOr HloCSE::Run(HloModule* module) { bool changed = false; + const std::function + eq_instructions = std::equal_to(); + const std::function + eq_computations = std::equal_to(); for (auto* computation : module->computations()) { changed |= CombineConstants(computation, is_layout_sensitive_); @@ -110,9 +115,11 @@ StatusOr HloCSE::Run(HloModule* module) { // of this instruction. const HloInstruction* operand = instruction->operand(0); - std::vector equivalent_instructions; + tensorflow::gtl::InlinedVector + equivalent_instructions; for (HloInstruction* user : operand->users()) { - if (user != instruction && user->Identical(*instruction) && + if (user != instruction && + user->Identical(*instruction, eq_instructions, eq_computations) && (!is_layout_sensitive_ || ShapeUtil::Equal(user->shape(), instruction->shape()))) { equivalent_instructions.push_back(user); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 3f34b9ceb34abc89fca5b896bb8fbe3a06cd6ed4..d25fc5d7418ae40c7167f88d6172906482a58925 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -154,7 +154,11 @@ bool HloDataflowAnalysis::Phi( tensorflow::gtl::ArraySlice inputs) { CHECK(ssa_form_); VLOG(4) << "Phi(" << instruction->name() << ")"; - + VLOG(5) << "instruction value set = " + << GetInstructionValueSet(instruction).ToString(); + for (const InstructionValueSet* input : inputs) { + VLOG(5) << "input value set = " << input->ToString(); + } for (const InstructionValueSet* input : inputs) { DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); } @@ -171,9 +175,14 @@ bool HloDataflowAnalysis::Phi( value_set.values().size() == 1 ? value_set.values()[0] : nullptr; // Construct a vector of unique value IDs of the inputs. + // Don't add value ids where the input is equal to the definition. std::vector input_value_ids; for (const InstructionValueSet* input : inputs) { for (const HloValue* value : input->element(index).values()) { + if (value->defining_instruction() == instruction && + value->defining_index() == index) { + continue; + } input_value_ids.push_back(value->id()); } } @@ -190,6 +199,7 @@ bool HloDataflowAnalysis::Phi( current_value->defining_instruction() == instruction && current_value->defining_index() == index); if (current_value_defined_here) { + VLOG(5) << "current_value_defined_here: " << current_value->ToString(); CHECK(current_value->is_phi()); auto it = std::find(input_value_ids.begin(), input_value_ids.end(), current_value->id()); @@ -197,7 +207,7 @@ bool HloDataflowAnalysis::Phi( input_value_ids.erase(it); } } - + VLOG(5) << "after input_value_ids.size = " << input_value_ids.size(); if (input_value_ids.empty()) { // A value set which has at least one element should never have its value // set reduced to zero elements. During dataflow value sets only can go @@ -276,6 +286,23 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } +bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) { + CHECK_EQ(slice->opcode(), HloOpcode::kSlice); + if (!slice->IsInPlaceSlice()) { + return false; + } + // If this slice is lowered to an in-place version, then it forwards the + // operand value to the output. + const InstructionValueSet& operand_set = + GetInstructionValueSet(slice->operand(0)); + InstructionValueSet& slice_set = GetInstructionValueSet(slice); + if (operand_set != slice_set) { + slice_set = operand_set; + return true; + } + return false; +} + bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { CHECK_EQ(send->opcode(), HloOpcode::kSend); bool changed = false; @@ -333,6 +360,21 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { return false; } +bool HloDataflowAnalysis::UpdateConditionalValueSet( + HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + std::vector inputs = { + &GetInstructionValueSet( + conditional->true_computation()->root_instruction()), + &GetInstructionValueSet( + conditional->false_computation()->root_instruction())}; + // A phi-node is not defined for a kConditional instruction even though it + // represents a join point. This is because the current approach is to define + // a phi-node only for kWhile to account for the dataflow through back-edges + // and deal with the ambiguity in other cases. + return GetInstructionValueSet(conditional).AssignUnionOf(inputs); +} + bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { CHECK_EQ(copy->opcode(), HloOpcode::kCopy); bool changed = false; @@ -394,7 +436,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { CHECK_EQ(call_graph_node.context(), CallContext::kSequential); std::vector inputs; - bool called_from_while = false; + bool need_phi = false; for (const CallSite& callsite : call_graph_node.caller_callsites()) { if (callsite.instruction()->opcode() == HloOpcode::kCall) { // The operand values of a call instruction are forwarded to the @@ -416,14 +458,32 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { inputs.push_back(&GetInstructionValueSet( callsite.instruction()->while_body()->root_instruction())); } - called_from_while = true; + need_phi = true; + } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + CHECK_EQ(parameter->parameter_number(), 0); + auto conditional = callsite.instruction(); + // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is + // the argument to the true computation and operand 2 is the argument to + // the false computation. + // + // If the parameter belongs to conditional's true computation, then + // operand 1 is forwarded to this parameter instruction. If the parameter + // belongs to conditional's false computation, then operand 2 is forwarded + // to this parameter instruction. + if (parameter->parent() == conditional->true_computation()) { + inputs.push_back(&GetInstructionValueSet(conditional->operand(1))); + } else { + CHECK_EQ(parameter->parent(), conditional->false_computation()); + inputs.push_back(&GetInstructionValueSet(conditional->operand(2))); + } + need_phi = true; } else { LOG(FATAL) << "CallContext::kSequential computations should only be " - "called from call or while instructions"; + "called from call, while, or conditional instructions"; } } - if (ssa_form_ && called_from_while) { + if (ssa_form_ && need_phi) { return Phi(parameter, inputs); } else { return GetInstructionValueSet(parameter).AssignUnionOf(inputs); @@ -494,6 +554,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( switch (instruction->opcode()) { case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); + case HloOpcode::kSlice: + return UpdateSliceValueSet(instruction); case HloOpcode::kCopy: return UpdateCopyValueSet(instruction); case HloOpcode::kGetTupleElement: @@ -512,6 +574,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateSendValueSet(instruction); case HloOpcode::kRecvDone: return UpdateRecvDoneValueSet(instruction); + case HloOpcode::kConditional: + return UpdateConditionalValueSet(instruction); default: // Instruction does not forward HloValues (it defines all values in its // output). No update is necessary. @@ -550,13 +614,31 @@ void HloDataflowAnalysis::Propagate() { // If user sequentially calls a computation, then the respective // parameter(s) of the computation need to be updated. - for (HloComputation* called_computation : user->called_computations()) { - const CallGraphNode& call_graph_node = - call_graph_->GetNode(called_computation); - if (call_graph_node.context() == CallContext::kSequential) { - for (int64 operand_number : user->OperandIndices(instruction)) { - worklist.push( - called_computation->parameter_instruction(operand_number)); + if (user->opcode() == HloOpcode::kConditional) { + // If operand 0 is the use of instruction, then no parameters need to be + // updated, since that is the predicate of the conditional. + // If operand 1 is the use of instruction, then the true_computation's + // parameter need to be updated. + // If operand 2 is the use of instruction, then the false_computation's + // parameter need to be updated. + // + // Note that the same instruction can be used in both operand 1 and + // operand 2. + if (user->operand(1) == instruction) { + worklist.push(user->true_computation()->parameter_instruction(0)); + } + if (user->operand(2) == instruction) { + worklist.push(user->false_computation()->parameter_instruction(0)); + } + } else { + for (HloComputation* called_computation : user->called_computations()) { + const CallGraphNode& call_graph_node = + call_graph_->GetNode(called_computation); + if (call_graph_node.context() == CallContext::kSequential) { + for (int64 operand_number : user->OperandIndices(instruction)) { + worklist.push( + called_computation->parameter_instruction(operand_number)); + } } } } @@ -568,7 +650,8 @@ void HloDataflowAnalysis::Propagate() { const CallGraphNode& call_graph_node = call_graph_->GetNode(instruction->parent()); for (const CallSite& callsite : call_graph_node.caller_callsites()) { - if (callsite.instruction()->opcode() == HloOpcode::kCall) { + if ((callsite.instruction()->opcode() == HloOpcode::kCall) || + (callsite.instruction()->opcode() == HloOpcode::kConditional)) { worklist.push(callsite.instruction()); } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { // Add the while itself, and the body and condition parameters. @@ -634,8 +717,14 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; + case HloOpcode::kSlice: + if (!instruction->IsInPlaceSlice()) { + define_all_values(); + } + break; case HloOpcode::kWhile: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kGetTupleElement: // These instructions define no values. The values in their output // flow from their operands or from cross computation dataflow. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index dfd81ae951042f7a4d6d3c24af4d5b7e046c272d..89d318188f0855c7924836a51cfe98d531e08cb4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -145,7 +145,9 @@ class HloDataflowAnalysis { // Updates the value set for a particular instruction type. Returns whether // the instruction value set changed. bool UpdateBitcastValueSet(HloInstruction* bitcast); + bool UpdateSliceValueSet(HloInstruction* slice); bool UpdateCallValueSet(HloInstruction* call); + bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index f08f0b1d6833b028baa5f997929a17eb5abae205..e714b2567fd1b3eab607a19f0bb7e3288150dc64 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { namespace { +using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; // Test is parameterized on a bool which is whether the dataflow analysis is @@ -77,11 +78,23 @@ class HloDataflowAnalysisTest : public HloTestBase, analysis_->GetValueDefinedAt(b), *analysis_); } + std::unique_ptr CreateR0F32UnaryOpComputation( + HloOpcode opcode) { + HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode)); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, opcode, param0)); + return builder.Build(); + } + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); }; TEST_P(HloDataflowAnalysisTest, BinaryOperation) { @@ -1528,6 +1541,315 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); } +TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { + // Test conditional with identity computations in both true and false cases. + // + // true_computation(F32[] %true_param): + // return %true_param + // + // false_computation(F32[] %false_param): + // return %false_param + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // return Conditional(%pred, %constant1, true_computation, + // %constant2, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "true_param")); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "false_param")); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, constant1, true_computation, constant2, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + ElementsAre(HloUse{conditional, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + ElementsAre(HloUse{conditional, 2, {}})); + + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); +} + +TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { + // Test conditional with true and false computations taking a tuple operand. + // + // true_computation((F32[], F32[]) %true_param): + // %true_x = GetTupleElement(%true_param, 0) + // %true_y = GetTupleElement(%true_param, 1) + // return Add(%true_x, %true_y) + // + // false_computation((F32[], F32[]) %false_param): + // %false_x = GetTupleElement(%false_param, 0) + // %false_y = GetTupleElement(%false_param, 1) + // return Subtract(%false_x, %false_y) + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // %tuple_operand = Tuple(%constant1, %constant2) + // return Conditional(%pred, %tuple_operand, true_computation, + // %tuple_operand, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "true_param")); + auto true_x = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0)); + auto true_y = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1)); + auto add = true_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, true_x, true_y)); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "false_param")); + auto false_x = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0)); + auto false_y = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1)); + auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kSubtract, false_x, false_y)); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, tuple_operand, true_computation, tuple_operand, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_y), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_y), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {0}}, + HloUse{conditional, 2, {0}}, + HloUse{add, 0, {}}, HloUse{sub, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {1}}, + HloUse{conditional, 2, {1}}, + HloUse{add, 1, {}}, HloUse{sub, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(), + UnorderedElementsAre( + HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}}, + HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, + HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); + + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); +} + +TEST_P(HloDataflowAnalysisTest, NestedConditionals) { + // computation1(F32[] %param1): + // %ceil = Ceil(%param1) + // return %ceil + // + // computation2(F32[] %param2): + // %floor = Floor(%param2) + // return %floor + // + // computation3(F32[] %param3): + // %negate = Negate(%param3) + // return %negate + // + // inner_conditional((PRED, F32[], F32[]) %param_cond): + // %pred_cond = GetTupleElement(%param_cond, 0) + // %true_operand_cond = GetTupleElement(%param_cond, 1) + // %false_opearnd_cond = GetTupleElement(%param_cond, 2) + // return Conditional(%pred_cond, %true_operand_cond, computation1, + // %false_operand_cond, computation2) + // + // entry: + // %pred1 = Constant(true) + // %pred2 = Constant(false) + // %constant1 = Constant(1.1); + // %constant2 = Constant(2.2); + // %constant3 = Constant(3.3); + // return Conditional(%pred1, (%pred2, %constant1, %constant2), + // inner_conditional, %constant3, computation3) + + auto computation1 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kCeil)); + auto computation2 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kFloor)); + auto computation3 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kNegate)); + + // Build inner_conditional computation. + const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {}); + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {scalar_bool_shape, scalar_shape_, scalar_shape_}); + auto inner_builder = + HloComputation::Builder(TestName() + "_inner_conditional"); + auto param_cond = inner_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond")); + auto pred_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0)); + auto true_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1)); + auto false_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2)); + auto inner_conditional = + inner_builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred_cond, true_operand_cond, computation1, + false_operand_cond, computation2)); + auto inner_conditional_computation = + module_->AddEmbeddedComputation(inner_builder.Build()); + + // Build entry computation. + auto builder = HloComputation::Builder(TestName()); + auto pred1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto pred2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.2f))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(3.3f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({pred2, constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred1, tuple_operand, inner_conditional_computation, + constant3, computation3)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction())); + + auto computation1_param = computation1->parameter_instruction(0); + auto computation2_param = computation2->parameter_instruction(0); + auto computation3_param = computation3->parameter_instruction(0); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param), + analysis.GetValueDefinedAt(constant3)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond)); + EXPECT_EQ(analysis.GetUniqueValueAt(param_cond), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond), + analysis.GetValueDefinedAt(pred2)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 40e67c87807b3e13d8ac09206bf6be02e4f9ff31..1e5f0f797a13fd7e7ce1cc934387a274a74153bc 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -55,7 +55,8 @@ StatusOr HloDCE::Run(HloModule* module) { for (auto* instruction : computation->instructions()) { if (instruction->user_count() == 0 && live_instructions.count(instruction) == 0 && - computation->IsRemovable(instruction)) { + computation->IsRemovable(instruction) && + !instruction->HasSideEffect()) { dead_roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index d54b9a27087a42fd23eab0bd06e8deaca567312b..5a56607a665c4cbeb7b2572f182b88e890602968 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -70,6 +70,26 @@ TEST_F(HloDceTest, NoDeadCode) { EXPECT_EQ(3, computation->instruction_count()); } +TEST_F(HloDceTest, InstructionsWithSideEffect) { + // Verify that side-effect instructions (Send in this test) are not removed. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateSend(constant, /*channel_id=*/0)); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, computation->instruction_count()); + + HloDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); +} + TEST_F(HloDceTest, DeadParameters) { // Verify that dead parameters are not removed, but use of the dead parameters // are. diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..1773bb401d380031f6c860d295e76d2f62c9e5ff --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -0,0 +1,137 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace { + +HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) { + if (hlo->shape().element_type() != type) { + Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); + hlo = hlo->parent()->AddInstruction( + HloInstruction::CreateConvert(shape, hlo)); + } + CHECK_EQ(hlo->shape().element_type(), type); + return hlo; +} + +bool HasOperandType(HloInstruction* hlo, PrimitiveType type) { + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == type) { + return true; + } + } + return false; +} + +} // namespace + +HloElementTypeConverter::HloElementTypeConverter( + PrimitiveType eliminate_type, PrimitiveType replace_with_type) + : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {} + +StatusOr HloElementTypeConverter::Run(HloModule* module) { + XLA_VLOG_LINES( + 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* hlo : computation->MakeInstructionPostOrder()) { + // These are ops where it does not make sense to convert them. + if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant || + hlo->opcode() == HloOpcode::kTuple || + hlo->opcode() == HloOpcode::kConvert || + hlo->opcode() == HloOpcode::kGetTupleElement || + hlo->opcode() == HloOpcode::kInfeed || + hlo->opcode() == HloOpcode::kOutfeed) { + continue; + } + + // We cannot change a CustomCall since we have no way of adjusting the + // called binary to expect the updated type. + if (hlo->opcode() == HloOpcode::kCustomCall) { + continue; + } + + // These are ops with embedded computations where it suffices to convert + // the embedded computations instead of converting the ops themselves. + if (hlo->opcode() == HloOpcode::kWhile || + hlo->opcode() == HloOpcode::kCall || + hlo->opcode() == HloOpcode::kFusion || + hlo->opcode() == HloOpcode::kMap || + hlo->opcode() == HloOpcode::kReduce || + hlo->opcode() == HloOpcode::kReduceWindow || + hlo->opcode() == HloOpcode::kSelectAndScatter || + hlo->opcode() == HloOpcode::kConditional) { + continue; + } + TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); + + if (!HasOperandType(hlo, eliminate_type_)) { + // If this CHECK fires, then this was an instruction that does not take + // the elimination type as an operand but it does return it. This pass + // does not have a feature to change the output type in that case, so + // instead of silently failing to eliminate the type, it fails loudly. + TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_); + continue; + } + + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == eliminate_type_) { + operand = ToElementType(operand, replace_with_type_); + } + new_operands.push_back(operand); + } + + HloInstruction* new_hlo; + if (hlo->shape().element_type() == eliminate_type_) { + Shape shape = + ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + new_hlo = ToElementType(new_hlo, eliminate_type_); + } else { + new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( + hlo->shape(), new_operands, hlo->GetModule())); + } + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo)); + changed = true; + } + } + XLA_VLOG_LINES( + 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..2b109225d0b192e5c9e4f6d841377ffad8078dc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass that eliminates certain element types as the input or output of ops by +// inserting Convert ops. This allows a backend to support an element type while +// only actually implementing the Convert op for that element type. This is +// generally not the fastest approach, but it works. +class HloElementTypeConverter : public HloPassInterface { + public: + // eliminate_type is the type to eliminate as the input or output of ops, + // using Convert ops to replace it with replace_with_type. + HloElementTypeConverter(PrimitiveType eliminate_type, + PrimitiveType replace_with_type); + + tensorflow::StringPiece name() const override { + return "element_type_converter"; + } + + // Returns the pass on the module and returns whether the module was modified. + StatusOr Run(HloModule* module) override; + + private: + PrimitiveType eliminate_type_; + PrimitiveType replace_with_type_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e693d167a1f96f65b894d07fb2c8f33e61ff8c49..3a846a752988efd618a1d6b9ed3c9e7a27627eee 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -167,11 +168,37 @@ StatusOr> ElementWiseUnaryOpImpl( } // namespace -template +template class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { public: explicit TypedVisitor(HloEvaluator* p) : parent_(p) {} + // The following higher-order functions convert a function with ElementwiseT + // to a function with ReturnT. + std::function ConvertUnaryFunction( + const std::function& unary_op) { + return [&unary_op](ReturnT arg) { + return static_cast(unary_op(static_cast(arg))); + }; + } + std::function ConvertBinaryFunction( + const std::function& + binary_op) { + return [&binary_op](ReturnT arg1, ReturnT arg2) { + return static_cast(binary_op(static_cast(arg1), + static_cast(arg2))); + }; + } + std::function ConvertTernaryFunction( + const std::function& ternary_op) { + return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { + return static_cast(ternary_op(static_cast(arg1), + static_cast(arg2), + static_cast(arg3))); + }; + } + Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", HloOpcodeString(hlo_instruction->opcode()).c_str()); @@ -197,24 +224,25 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { is_complex_t::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) { return std::abs(elem_operand); })); return Status::OK(); } Status HandleAbs(HloInstruction* abs) override { - return HandleAbs(abs); + return HandleAbs(abs); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRound(HloInstruction* round) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[round], - ElementWiseUnaryOp(round, [](ReturnT elem_operand) { - return std::round(elem_operand); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[round], + ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { + return std::round(elem_operand); + })); return Status::OK(); } @@ -233,7 +261,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[broadcast] = Literal::CreateFromShape(broadcast->shape()); auto output = parent_->evaluated_[broadcast].get(); - auto operand_to_broadcast = + const Literal& operand_to_broadcast = parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); std::vector broadcast_indices( ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); @@ -264,7 +292,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleCeil(HloInstruction* ceil) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], - ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { + ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { return std::ceil(elem_operand); })); return Status::OK(); @@ -299,7 +327,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleExp(HloInstruction* exp) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], - ElementWiseUnaryOp(exp, [](ReturnT elem_operand) { + ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { return std::exp(elem_operand); })); return Status::OK(); @@ -309,10 +337,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleFloor(HloInstruction* floor) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor], - ElementWiseUnaryOp(floor, [](ReturnT elem_operand) { - return std::floor(elem_operand); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { + return std::floor(elem_operand); + })); return Status::OK(); } @@ -329,7 +358,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], - ElementWiseUnaryOp(log, [](ReturnT elem_operand) { + ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { return std::log(elem_operand); })); return Status::OK(); @@ -341,7 +370,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { !std::is_same::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { return ~elem_operand; })); return Status::OK(); @@ -351,7 +380,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { NativeT>::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { return !elem_operand; })); return Status::OK(); @@ -362,7 +391,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { nullptr> Status HandleNot(HloInstruction* not_) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { return !elem_operand; })); return Status::OK(); @@ -376,7 +405,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleNot(HloInstruction* not_) override { - return HandleNot(not_); + return HandleNot(not_); } template ::value>::type* = nullptr> Status HandleNegate(HloInstruction* negate) { using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { - return NativeT(-type(elem_operand)); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { + return NativeT(-type(elem_operand)); + })); return Status::OK(); } @@ -397,10 +427,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { !std::is_signed::value || std::is_floating_point::value>::type* = nullptr> Status HandleNegate(HloInstruction* negate) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { - return -elem_operand; - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp( + negate, [](ElementwiseT elem_operand) { return -elem_operand; })); return Status::OK(); } @@ -413,9 +443,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { - return (ReturnT(0) < elem_operand) - - (elem_operand < ReturnT(0)); + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return (ElementwiseT(0) < elem_operand) - + (elem_operand < ElementwiseT(0)); })); return Status::OK(); } @@ -425,9 +455,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { auto abs_val = std::abs(elem_operand); - return 0 == abs_val ? ReturnT(0) + return 0 == abs_val ? ElementwiseT(0) : elem_operand / abs_val; })); return Status::OK(); @@ -437,9 +467,30 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleSign(sign); } + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], + ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return std::atan2(lhs_elem, rhs_elem); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + return InvalidArgument("Unsupported type for Atan2"); + } + + Status HandleAtan2(HloInstruction* atan2) override { + return HandleAtan2(atan2); + } + Status HandleTanh(HloInstruction* tanh) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], - ElementWiseUnaryOp(tanh, [](ReturnT elem_operand) { + ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { return std::tanh(elem_operand); })); return Status::OK(); @@ -453,9 +504,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { using type = typename std::make_unsigned::type; TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); return Status::OK(); } @@ -467,40 +519,42 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleMultiply(HloInstruction* multiply) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem * rhs_elem; - })); + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem * rhs_elem; + })); return Status::OK(); } Status HandleMultiply(HloInstruction* multiply) override { - return HandleMultiply(multiply); + return HandleMultiply(multiply); } Status HandleSubtract(HloInstruction* subtract) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem - rhs_elem; - })); + ElementWiseBinaryOp(subtract, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem - rhs_elem; + })); return Status::OK(); } Status HandleAdd(HloInstruction* add) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[add], - ElementWiseBinaryOp(add, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem + rhs_elem; - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], + ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem + rhs_elem; + })); return Status::OK(); } Status HandleDivide(HloInstruction* divide) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[divide], - ElementWiseBinaryOp(divide, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem / rhs_elem; - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem / rhs_elem; + })); return Status::OK(); } @@ -510,7 +564,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleMaximum(HloInstruction* maximum) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { return std::fmax(lhs, rhs); })); return Status::OK(); @@ -524,18 +578,18 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleMaximum(HloInstruction* maximum) override { - return HandleMaximum(maximum); + return HandleMaximum(maximum); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::fmin(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmin(lhs_el, rhs_el); + })); return Status::OK(); } @@ -547,15 +601,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleMinimum(HloInstruction* minimum) override { - return HandleMinimum(minimum); + return HandleMinimum(minimum); } Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); return Status::OK(); } @@ -563,11 +617,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[remainder], - ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::fmod(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmod(lhs_el, rhs_el); + })); return Status::OK(); } @@ -579,7 +633,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleRemainder(HloInstruction* remainder) override { - return HandleRemainder(remainder); + return HandleRemainder(remainder); } template evaluated_[and_], - ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el & rhs_el; })); return Status::OK(); @@ -599,7 +653,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleAnd(HloInstruction* and_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el && rhs_el; })); return Status::OK(); @@ -613,7 +667,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleAnd(HloInstruction* and_) override { - return HandleAnd(and_); + return HandleAnd(and_); } template evaluated_[or_], - ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el | rhs_el; })); return Status::OK(); @@ -633,7 +687,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleOr(HloInstruction* or_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el || rhs_el; })); return Status::OK(); @@ -647,7 +701,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleOr(HloInstruction* or_) override { - return HandleOr(or_); + return HandleOr(or_); } template (shl); + return HandleShiftLeft(shl); } template (shra); + return HandleShiftRightArithmetic(shra); } template (shrl); + return HandleShiftRightLogical(shrl); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleClamp(HloInstruction* clamp) { - std::function clamp_op = - [](ReturnT low, ReturnT value, ReturnT high) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { return std::fmax(low, std::fmin(value, high)); }; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], - ElementWiseTernaryOp(clamp, std::move(clamp_op))); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); return Status::OK(); } @@ -749,7 +805,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleClamp(HloInstruction* clamp) override { - return HandleClamp(clamp); + return HandleClamp(clamp); } Status HandleSelect(HloInstruction* select) override { @@ -762,7 +818,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return on_false; }; TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], - ElementWiseTernaryOp(select, std::move(select_op))); + ElementwiseTernaryOp(select, std::move(select_op))); return Status::OK(); } @@ -780,7 +836,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto result = Literal::CreateFromShape(result_shape); TF_RETURN_IF_ERROR(result->Populate( @@ -860,7 +916,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size()); auto func = [&](tensorflow::gtl::ArraySlice out_index) { - ReturnT result_val = static_cast(0); + ElementwiseT result_val = static_cast(0); std::fill(lhs_index.begin(), lhs_index.end(), 0); std::fill(rhs_index.begin(), rhs_index.end(), 0); @@ -911,13 +967,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { : rhs_spatial_index[ki]; } - result_val += lhs_literal.Get(lhs_index) * - rhs_literal.Get(rhs_index); + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); } cnt : {} } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); - return result_val; + return static_cast(result_val); }; auto result = Literal::CreateFromShape(result_shape); @@ -967,7 +1024,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { auto result = Literal::CreateFromShape(dot->shape()); TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { - ReturnT result_val = static_cast(0); + ElementwiseT result_val = static_cast(0); std::vector lhs_index(lhs_rank, 0); std::vector rhs_index(rhs_rank, 0); @@ -984,11 +1041,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { lhs_index[lhs_contracted_dimension] = i; rhs_index[rhs_contracted_dimension] = i; - result_val += lhs_literal.Get(lhs_index) * - rhs_literal.Get(rhs_index); + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); } - return result_val; + return static_cast(result_val); })); parent_->evaluated_[dot] = std::move(result); @@ -1021,7 +1079,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return scalar; })); - auto evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); + const Literal& evaluated_operand = + parent_->GetEvaluatedLiteralFor(pad->operand(0)); std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), 0); @@ -1174,6 +1233,97 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template + StatusOr> MapImpl(HloInstruction* map) { + auto operands = map->operands(); + HloComputation* computation = map->to_apply(); + + auto result = Literal::CreateFromShape(map->shape()); + + HloEvaluator embedded_evaluator; + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + std::vector> arg_literals; + arg_literals.reserve(operands.size()); + + // Construct scalar literal parameters to be passed to the map + // computation. + for (auto operand : operands) { + const Literal& arg_literal = + parent_->GetEvaluatedLiteralFor(operand); + + auto curr_val = arg_literal.Get(multi_index); + auto curr_val_literal = Literal::CreateR0(curr_val); + + arg_literals.push_back(std::move(curr_val_literal)); + } + + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate>(*computation, + arg_literals) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + + return computed_result->Get({}); + })); + return std::move(result); + } + + Status HandleMap(HloInstruction* map) override { + switch (map->operand(0)->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case C64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + default: + LOG(FATAL) << "HandleMap: unhandled primitive type for " + "input operand: " + << PrimitiveType_Name( + map->operand(0)->shape().element_type()); + } + + return Status::OK(); + } + Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); @@ -1220,6 +1370,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } } + HloEvaluator embedded_evaluator; // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -1239,13 +1390,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::vector args = {curr_val_literal.get(), result_val_literal.get()}; - // We need a new visitor for each evaluation, so that the same - // computation can be visited more than once (with different - // inputs). - HloEvaluator embedded_evaluator; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator.Evaluate(*function, args) .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. + embedded_evaluator.ResetVisitStates(); // Assign computed result to result_val. result_val = computed_result->Get({}); @@ -1302,6 +1452,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector window_index(window.dimensions_size()); DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + HloEvaluator embedded_evaluator; // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice output_index) { @@ -1311,8 +1462,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::fill(operand_index.begin(), operand_index.end(), 0); do { - // Set curr_val to 0 if out of bound (padded). - ReturnT curr_val = static_cast(0); bool out_of_bound = false; for (int i = 0; i < operand_index.size(); ++i) { operand_index[i] = @@ -1325,23 +1474,25 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } } if (!out_of_bound) { - curr_val = operand_literal.Get(operand_index); + auto curr_val = operand_literal.Get(operand_index); + + // Evaluate computation with specified literal operands. + const auto curr_val_literal = + Literal::CreateR0(curr_val); + const auto result_val_literal = + Literal::CreateR0(result_val); + const std::vector args = { + curr_val_literal.get(), result_val_literal.get()}; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*function, args) + .ConsumeValueOrDie(); + + // Clear visit states so that the we can use the evaluate again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + + result_val = computed_result->Get({}); } - // Evaluate computation with specified literal operands. - const auto curr_val_literal = Literal::CreateR0(curr_val); - const auto result_val_literal = - Literal::CreateR0(result_val); - const std::vector args = {curr_val_literal.get(), - result_val_literal.get()}; - // We need a new visitor for each evaluation, so that the same - // computation can be visited more than once (with different - // inputs). - HloEvaluator embedded_evaluator; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - - result_val = computed_result->Get({}); } while (IndexUtil::BumpIndices(window_shape, &window_index)); return result_val; @@ -1364,7 +1515,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const int64 rank = ShapeUtil::Rank(operand->shape()); - auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto func = [&](tensorflow::gtl::ArraySlice out_index) { DimensionVector operand_index(rank); for (int64 i = 0; i < rank; ++i) { @@ -1385,7 +1536,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { NativeT>::value>::type* = nullptr> Status HandleSin(HloInstruction* sin) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], - ElementWiseUnaryOp(sin, [](ReturnT elem_operand) { + ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { return std::sin(elem_operand); })); return Status::OK(); @@ -1400,14 +1551,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleSin(HloInstruction* sin) override { - return HandleSin(sin); + return HandleSin(sin); } template ::value>::type* = nullptr> Status HandleCos(HloInstruction* cos) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], - ElementWiseUnaryOp(cos, [](ReturnT elem_operand) { + ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { return std::cos(elem_operand); })); return Status::OK(); @@ -1422,7 +1573,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleCos(HloInstruction* cos) override { - return HandleCos(cos); + return HandleCos(cos); } private: @@ -1430,8 +1581,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> DynamicSlice( const Literal& operand_literal, const Literal& start_indices_literal, const Shape& result_shape) { - const auto& start_indices_typed = - start_indices_literal.GetArraySlice(); + auto start_indices_typed = start_indices_literal.data(); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); @@ -1459,12 +1609,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> DynamicUpdateSlice( const Literal& operand_literal, const Literal& update_literal, const Literal& start_indices_literal) { - const auto& start_indices_typed = - start_indices_literal.GetArraySlice(); + auto start_indices_typed = start_indices_literal.data(); const std::vector start(start_indices_typed.begin(), start_indices_typed.end()); - auto result = MakeUnique(operand_literal); + auto result = operand_literal.CloneToUnique(); std::vector result_index(ShapeUtil::Rank(result->shape()), 0); auto func = [&](const std::vector& update_index) { @@ -1487,22 +1636,27 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> ElementWiseUnaryOp( HloInstruction* instruction, - const std::function& unary_op) { + const std::function& unary_op) { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(instruction->operand(0)); - return ElementWiseUnaryOpImpl(instruction, unary_op, - operand_literal); + TF_ASSIGN_OR_RETURN( + auto result_literal, + (ElementWiseUnaryOpImpl( + instruction, ConvertUnaryFunction(unary_op), operand_literal))); + + return std::move(result_literal); } StatusOr> ElementWiseBinaryOp( HloInstruction* instruction, - const std::function& binary_op) { + const std::function& + binary_op) { const auto shape = instruction->shape(); const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast + // is removed. if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { return Unimplemented( @@ -1520,14 +1674,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { - return binary_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); + return ConvertBinaryFunction(binary_op)( + lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); })); return std::move(result); } template - StatusOr> ElementWiseTernaryOp( + StatusOr> ElementwiseTernaryOp( HloInstruction* instruction, const std::function& ternary_op) { const auto shape = instruction->shape(); @@ -1589,9 +1744,11 @@ HloEvaluator::HloEvaluator() { typed_visitors_[F64] = MakeUnique>(this); typed_visitors_[C64] = MakeUnique>(this); - typed_visitors_[BF16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: BF16."); - }); + // Most of the evaluator computations we use don't support BF16 (e.g., + // std::ceil, std::tanh). To make evaluator work with BF16, we set all + // elementwise computations to be done in F32 and do BF16<->F32 conversion + // around the input and the output of the computations. + typed_visitors_[BF16] = MakeUnique>(this); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); }); @@ -1600,41 +1757,53 @@ HloEvaluator::HloEvaluator() { }); } +template StatusOr> HloEvaluator::Evaluate( const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals) { + tensorflow::gtl::ArraySlice arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); - arg_literals_ = arg_literals; evaluated_.clear(); + arg_literals_.clear(); + for (const auto& literal_ptr : arg_literals) { + arg_literals_.push_back(&*literal_ptr); + } TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); - return MakeUnique( - GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())); + return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) + .CloneToUnique(); } +template StatusOr> HloEvaluator::Evaluate( const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals) { + tensorflow::gtl::ArraySlice arg_literals) { XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); - arg_literals_ = arg_literals; + evaluated_.clear(); + arg_literals_.clear(); + for (const auto& literal_ptr : arg_literals) { + arg_literals_.push_back(&*literal_ptr); + } TF_RETURN_IF_ERROR(computation.Accept(this)); - return MakeUnique( - GetEvaluatedLiteralFor(computation.root_instruction())); + return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique(); } +template StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction, - tensorflow::gtl::ArraySlice operands) { + tensorflow::gtl::ArraySlice arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); - arg_literals_ = operands; evaluated_.clear(); + arg_literals_.clear(); + for (const auto& literal_ptr : arg_literals) { + arg_literals_.push_back(&*literal_ptr); + } // Evaluate operands of Parameter type against the input literals which // caches the evaluated literal results. @@ -1645,14 +1814,14 @@ StatusOr> HloEvaluator::Evaluate( << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - evaluated_[operand] = MakeUnique(*input_literal); + evaluated_[operand] = input_literal->CloneToUnique(); } } TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return MakeUnique(GetEvaluatedLiteralFor(instruction)); + return GetEvaluatedLiteralFor(instruction).CloneToUnique(); } StatusOr> HloEvaluator::Evaluate( @@ -1673,7 +1842,7 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return MakeUnique(GetEvaluatedLiteralFor(instruction)); + return GetEvaluatedLiteralFor(instruction).CloneToUnique(); } std::unique_ptr HloEvaluator::TryEvaluate( @@ -1722,11 +1891,15 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( } Status HloEvaluator::HandleParameter(HloInstruction* parameter) { + CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); - DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); + DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())) + << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape()) + << ", but input literal shape is: " + << ShapeUtil::HumanString(input_literal->shape()); - evaluated_[parameter] = MakeUnique(*input_literal); + evaluated_[parameter] = input_literal->CloneToUnique(); return Status::OK(); } @@ -1749,8 +1922,8 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { tensorflow::gtl::ArraySlice operands( concatenate->operands()); - // The result concatenate dimension is going to be the sum of all concatenate - // dimensions of the operands taking part of the operation. + // The result concatenate dimension is going to be the sum of all + // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); CHECK(!ShapeUtil::IsTuple(reference_shape)); const int64 rank = ShapeUtil::Rank(reference_shape); @@ -1777,7 +1950,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(result_literal->Copy( + TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( GetEvaluatedLiteralFor(operand), source_indices, dest_indices, AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += @@ -1935,16 +2108,17 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = - MakeUnique(operand_tuple_literal.tuple_literals(index)); - - return Status::OK(); + evaluated_[get_tuple_element] = MakeUnique( + ShapeUtil::GetTupleElementShape(operand->shape(), index)); + return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, + /*dest_shape_index=*/{}, + /*src_shape_index=*/{index}); } Status HloEvaluator::HandleCopy(HloInstruction* copy) { TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); - auto result = MakeUnique(GetEvaluatedLiteralFor(copy->operand(0))); + auto result = GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique(); evaluated_[copy] = std::move(result); return Status::OK(); } @@ -1960,4 +2134,30 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { return Status::OK(); } +// Explicit instantiation of templatized Evaluate* methods. +// +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(const HloModule& module, + tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate>( + const HloModule& module, + tensorflow::gtl::ArraySlice> arg_literals); + +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(const HloComputation& computation, + tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate>( + const HloComputation& computation, + tensorflow::gtl::ArraySlice> arg_literals); + +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(HloInstruction* instruction, + tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate>( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice> arg_literals); + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 7557aaa2484d184555411a79d8dce2c9241427b0..02bb8b0a47065c359603a113f49626bf3ad344d8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -42,9 +42,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. + // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // type. + template StatusOr> Evaluate( const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); + tensorflow::gtl::ArraySlice arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -62,9 +65,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. + // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // type. + template StatusOr> Evaluate( const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); + tensorflow::gtl::ArraySlice arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -72,10 +78,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // 1. argument literals correspond to the input instruction's parameters in // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. - // TODO(b/35950897): implement more ops other than element-wise ops. + // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // type. + template StatusOr> Evaluate( HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); + tensorflow::gtl::ArraySlice arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. @@ -100,12 +108,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault { protected: // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting // literal type of each evaluated Handle* method of a TypedVisitor. - // There are however a few notable exceptions to this is rule, notably: + // There are however a few notable exceptions to this rule, notably: // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. // These operations are handled outside of the parent HloEvaluator handlers // instead of from within TypedVisitor. - template + // + // Type params: + // - ReturnT: The type of input and output of each operation. + // - ElementwiseT: The type in which internal computation are done. + template class TypedVisitor; // Wraps around instruction handling to infer types before dispatching to @@ -134,6 +146,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleIsFinite(HloInstruction* is_finite) override; Status HandleCompare(HloInstruction* compare) override; + Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; @@ -167,13 +180,15 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in // post-orderring. + // Must be cleared for each evaluation. tensorflow::gtl::FlatMap> evaluated_; - // Stores input literals, assuming they are in post-order. Literals are not - // owned by this class, and they must outlive the lifetime of the instance of - // this class. - tensorflow::gtl::ArraySlice arg_literals_; + // Caches pointers to input literals, assuming they are in post-order. + // Literals are not owned by this class, and they must outlive the lifetime of + // each invocation to the Evaluate* method. + // Must be cleared for each evaluation. + std::vector arg_literals_; TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index b2c4351896764fa8683e91396f526d97ba208df6..97765d65909cee192f65069777f8f195081603b2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -25,8 +25,10 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -35,15 +37,33 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -class HloEvaluatorTest : public HloVerifiedTestBase { +static std::array use_bf16_params{true, false}; + +class HloEvaluatorTest : public ::testing::WithParamInterface, + public HloVerifiedTestBase { protected: - HloEvaluatorTest() { evaluator_ = MakeUnique(); } + HloEvaluatorTest() : use_bfloat16_(GetParam()) { + evaluator_ = MakeUnique(); + } + + std::unique_ptr Evaluate( + tensorflow::gtl::ArraySlice arg_literals = {}) { + if (use_bfloat16_) { + // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. + auto type_converter = HloElementTypeConverter(F32, BF16); + type_converter.Run(&module()).ValueOrDie(); + } + return evaluator_->Evaluate(*module().entry_computation(), arg_literals) + .ConsumeValueOrDie(); + } std::unique_ptr evaluator_; @@ -52,12 +72,11 @@ class HloEvaluatorTest : public HloVerifiedTestBase { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateUnary(expected->shape(), opcode, c1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto element_type = expected->shape().element_type(); if (element_type == F32 || element_type == F64) { @@ -74,20 +93,24 @@ class HloEvaluatorTest : public HloVerifiedTestBase { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); LiteralTestUtil::ExpectEqual(*expected, *result); } + + bool use_bfloat16_; }; +#define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ + TEST_P(test_case_name, test_name) + // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. -TEST_F(HloEvaluatorTest, DoesClamp) { +TEST_P(HloEvaluatorTest, DoesClamp) { auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); @@ -97,19 +120,18 @@ TEST_F(HloEvaluatorTest, DoesClamp) { auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { +TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto low = Literal::CreateR0(0.f); auto value = Literal::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = Literal::CreateR0(1.f); @@ -119,12 +141,11 @@ TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); @@ -133,7 +154,7 @@ TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. -TEST_F(HloEvaluatorTest, DoesSelect) { +TEST_P(HloEvaluatorTest, DoesSelect) { auto pred = Literal::CreateR2({{true, false}, {false, true}}); auto on_true = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); @@ -145,12 +166,11 @@ TEST_F(HloEvaluatorTest, DoesSelect) { b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true))); auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); @@ -159,7 +179,7 @@ TEST_F(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. -TEST_F(HloEvaluatorTest, DoesAdd) { +TEST_P(HloEvaluatorTest, DoesAdd) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); @@ -168,7 +188,7 @@ TEST_F(HloEvaluatorTest, DoesAdd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise and with 2 operands. -TEST_F(HloEvaluatorTest, DoesAnd) { +TEST_P(HloEvaluatorTest, DoesAnd) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{0, 0}, {4, 4}}); @@ -177,7 +197,7 @@ TEST_F(HloEvaluatorTest, DoesAnd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. -TEST_F(HloEvaluatorTest, DoesOr) { +TEST_P(HloEvaluatorTest, DoesOr) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{3, 4}, {-100, 4}}); @@ -186,7 +206,7 @@ TEST_F(HloEvaluatorTest, DoesOr) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. -TEST_F(HloEvaluatorTest, DoesMultiply) { +TEST_P(HloEvaluatorTest, DoesMultiply) { auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2( {{std::numeric_limits::min(), 4}, {4, 4}}); @@ -197,14 +217,14 @@ TEST_F(HloEvaluatorTest, DoesMultiply) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST_F(HloEvaluatorTest, DoesDivideInt64) { +TEST_P(HloEvaluatorTest, DoesDivideInt64) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } -TEST_F(HloEvaluatorTest, DoesDivideDouble) { +TEST_P(HloEvaluatorTest, DoesDivideDouble) { auto lhs = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); auto rhs = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = @@ -215,40 +235,41 @@ TEST_F(HloEvaluatorTest, DoesDivideDouble) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. -TEST_F(HloEvaluatorTest, DoesAbsR2) { +TEST_P(HloEvaluatorTest, DoesAbsR2) { auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesAbsR0) { +TEST_P(HloEvaluatorTest, DoesAbsR0) { auto operand = Literal::CreateR0(-1.0f); auto expected = Literal::CreateR0(1.0f); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesAbsR1WithZeroSize) { +TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { auto operand = Literal::CreateR1({}); auto expected = Literal::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesNegateR2) { +TEST_P(HloEvaluatorTest, DoesNegateR2) { auto operand = Literal::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); auto expected = Literal::CreateR2({{0, std::numeric_limits::min()}, {1, -4}}); TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesCosR2) { +TEST_P(HloEvaluatorTest, DoesCosR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); - TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand)); + TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), + use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); } -TEST_F(HloEvaluatorTest, DoesSinR2) { +TEST_P(HloEvaluatorTest, DoesSinR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), - 0x1.0P-20); + use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); } -TEST_F(HloEvaluatorTest, DoesNotR2) { +TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = Literal::CreateR2({{0, std::numeric_limits::min()}, {-1, std::numeric_limits::max()}}); @@ -259,7 +280,7 @@ TEST_F(HloEvaluatorTest, DoesNotR2) { } // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. -TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { +TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = Literal::CreateR2({{1, -20}, {-100, 4}}); @@ -279,10 +300,9 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2")); b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs_instruction, param_rhs2)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, args).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(args); auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); @@ -290,7 +310,7 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { } // Verifies Reshape operation is correctly evaluated. -TEST_F(HloEvaluatorTest, DoesReshape) { +TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, @@ -304,21 +324,20 @@ TEST_F(HloEvaluatorTest, DoesReshape) { const int64 permutation[] = {1, 2, 0, 4, 3}; b.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; result->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_TRUE(value == literal_clone->Get(rindexes)); + EXPECT_NEAR(value, literal_clone->Get(rindexes), 0x1.0P-5); }); } // Verifies Broadcast operation is correctly evaluated. -TEST_F(HloEvaluatorTest, DoesBroadcast) { +TEST_P(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto output_literal = Literal::CreateR3( @@ -327,15 +346,14 @@ TEST_F(HloEvaluatorTest, DoesBroadcast) { HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, {1, 2})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); LiteralTestUtil::ExpectEqual(*result, *output_literal); } -TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { +TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR0(111); auto output_literal = Literal::CreateR2( @@ -347,15 +365,14 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, /*broadcast_dimensions=*/{})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); LiteralTestUtil::ExpectEqual(*result, *output_literal); } -TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { +TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( @@ -368,17 +385,16 @@ TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { Shape shape = ShapeUtil::MakeShape(S64, {4, 2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { +TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( @@ -391,16 +407,15 @@ TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { Shape shape = ShapeUtil::MakeShape(S64, {2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({100, 200}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { +TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); @@ -412,15 +427,14 @@ TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); LiteralTestUtil::ExpectEqual(*result, *expected); } -TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { +TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR2WithLayout( @@ -433,10 +447,9 @@ TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); LiteralTestUtil::ExpectEqual(*result, *expected); } @@ -454,7 +467,7 @@ PaddingConfig CreatePaddingConfig( return padding_config; } -TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { +TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto operand = Literal::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = @@ -467,11 +480,11 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto padding_config = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}}); Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); - auto pad_instruction = b.AddInstruction(HloInstruction::CreatePad( + b.AddInstruction(HloInstruction::CreatePad( shape, operand_instruction, padding_value_instruction, padding_config)); module().AddEntryComputation(b.Build()); - auto result = evaluator_->Evaluate(pad_instruction).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); @@ -479,7 +492,7 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); @@ -496,10 +509,9 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); b.AddInstruction(HloInstruction::CreatePad( shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected_array = MakeUnique>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -515,7 +527,7 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, NegativePadding2D) { +TEST_P(HloEvaluatorTest, NegativePadding2D) { HloComputation::Builder b(TestName()); // input_array: @@ -541,10 +553,9 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = MakeUnique>(1, 5); @@ -555,10 +566,10 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5)); } -TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { +TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { HloComputation::Builder b(TestName()); // f32[4,3] { @@ -587,10 +598,9 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected_array = MakeUnique>(0, 9); auto expected = Literal::CreateR2FromArray2D(*expected_array); @@ -598,7 +608,7 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DotRank2AndRank1) { +TEST_P(HloEvaluatorTest, DotRank2AndRank1) { HloComputation::Builder b(TestName()); // lhs: @@ -621,12 +631,14 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - auto computation = module().AddEntryComputation(b.Build()); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off auto expected_array = Array2D({ @@ -641,7 +653,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DotRank1AndRank2) { +TEST_P(HloEvaluatorTest, DotRank1AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -664,19 +676,21 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - auto computation = module().AddEntryComputation(b.Build()); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({22.f, 28.f}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DotRank2AndRank2) { +TEST_P(HloEvaluatorTest, DotRank2AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -705,12 +719,14 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - auto computation = module().AddEntryComputation(b.Build()); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected_array = Array2D({ {22.f, 28.f}, @@ -723,7 +739,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, SimpleConv1D) { +TEST_P(HloEvaluatorTest, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; @@ -761,10 +777,9 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = Literal::CreateR3FromArray3D(expected_array); @@ -772,7 +787,7 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { +TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -816,10 +831,9 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ -835,7 +849,7 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { +TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { HloComputation::Builder b(TestName()); // clang-format off @@ -900,21 +914,22 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); + Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = Literal::CreateR4FromArray4D( + use_bfloat16_ ? expected_array_bf16 : expected_array); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { +TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { HloComputation::Builder b(TestName()); // clang-format off @@ -976,21 +991,22 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); + Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = Literal::CreateR4FromArray4D( + use_bfloat16_ ? expected_array_bf16 : expected_array); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { +TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1034,10 +1050,9 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -1054,7 +1069,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { +TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1098,10 +1113,9 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -1119,7 +1133,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, +TEST_P(HloEvaluatorTest, DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) { HloComputation::Builder b(TestName()); @@ -1170,10 +1184,9 @@ TEST_F(HloEvaluatorTest, const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1192,7 +1205,7 @@ TEST_F(HloEvaluatorTest, LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceAdd) { +TEST_P(HloEvaluatorTest, ReduceAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1225,17 +1238,16 @@ TEST_F(HloEvaluatorTest, ReduceAdd) { HloInstruction::CreateReduce(shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1}, add_func)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({6, 18}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceWindowMax) { +TEST_P(HloEvaluatorTest, ReduceWindowMax) { HloComputation::Builder b(TestName()); // arg: @@ -1278,15 +1290,15 @@ TEST_F(HloEvaluatorTest, ReduceWindowMax) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{6, 7}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceWindowAdd) { +TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1335,21 +1347,21 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { +TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { HloComputation::Builder b(TestName()); // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. std::vector input_dims(6, 4); std::unique_ptr arg_literal = - Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); + Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); @@ -1396,17 +1408,17 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = - Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 8.0f); + Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); LiteralTestUtil::ExpectEqual(*result_literal, *result); } -TEST_F(HloEvaluatorTest, StridedSlice) { +TEST_P(HloEvaluatorTest, StridedSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1427,10 +1439,9 @@ TEST_F(HloEvaluatorTest, StridedSlice) { /*start_indices=*/{0, 2}, /*limit_indices=*/{3, 5}, /*strides=*/{2, 3})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {3}, @@ -1440,7 +1451,7 @@ TEST_F(HloEvaluatorTest, StridedSlice) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DynamicSlice) { +TEST_P(HloEvaluatorTest, DynamicSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1461,10 +1472,9 @@ TEST_F(HloEvaluatorTest, DynamicSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {2, 3, 4}, @@ -1476,7 +1486,7 @@ TEST_F(HloEvaluatorTest, DynamicSlice) { // Verifies that the HloEvaluator's implementation goes along with existing // backends' behavior, although this is not required by the spec. -TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { +TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1497,10 +1507,9 @@ TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {2, 3, 4}, @@ -1510,7 +1519,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { +TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { HloComputation::Builder b(TestName()); // arg: @@ -1534,10 +1543,9 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, operand, update, start_indices)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {1, -2, -3}, @@ -1547,7 +1555,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, SetAndGetTuples) { +TEST_P(HloEvaluatorTest, SetAndGetTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1570,9 +1578,9 @@ TEST_F(HloEvaluatorTest, SetAndGetTuples) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {1, 2, 3}, @@ -1582,7 +1590,7 @@ TEST_F(HloEvaluatorTest, SetAndGetTuples) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { +TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1609,9 +1617,9 @@ TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { b.AddInstruction( HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto result_inner_literal = Literal::CreateR2FromArray2D(*operand_array); @@ -1623,7 +1631,7 @@ TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Reverse) { +TEST_P(HloEvaluatorTest, Reverse) { HloComputation::Builder b(TestName()); // Input shape is float[4x3x2x1]. @@ -1649,10 +1657,9 @@ TEST_F(HloEvaluatorTest, Reverse) { const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1}); b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off auto expected = Literal::CreateR4FromArray4D({ @@ -1677,7 +1684,7 @@ TEST_F(HloEvaluatorTest, Reverse) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, EvaluateWithSubstitutions) { +TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1700,7 +1707,7 @@ TEST_F(HloEvaluatorTest, EvaluateWithSubstitutions) { // Check that EvaluateWithSubstitutions works if one of the operands to the op // we're evaluating is a constant. -TEST_F(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { +TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1722,5 +1729,8 @@ TEST_F(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { *result.ValueOrDie()); } +INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, + ::testing::ValuesIn(use_bf16_params)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index ba75e2ef1b485f015f3b8f8dbd76f214d6ab0130..849aac0b12b096e5f7c4a5c441fc019c48a27060 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -32,7 +32,7 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { InsertOrDie(&computation_to_profile_idx_, computation, current_profile_index++); for (const HloInstruction* instruction : computation->instructions()) { - // For simplicity we track all instrutions here, but we could skip + // For simplicity we track all instructions here, but we could skip // non-executing instructions like constants and parameters. InsertOrDie(&instruction_to_profile_idx_, instruction, current_profile_index++); @@ -76,8 +76,8 @@ std::unique_ptr CreateHloProfilePrinter( HloProfilePrinter::HloInstructionInfo* instruction_info = &computation_info->instructions[instruction_index_in_static_data++]; instruction_info->long_name = strdup(hlo->ToString().c_str()); - instruction_info->short_name = - strdup(hlo->ToString(/*compact_operands=*/true).c_str()); + instruction_info->short_name = strdup( + hlo->ToString(HloPrintOptions().set_compact_operands(true)).c_str()); instruction_info->category = strdup(hlo->ToCategory().c_str()); instruction_info->flop_count = cost_analysis.flop_count(*hlo); instruction_info->transcendental_count = @@ -109,7 +109,8 @@ std::unique_ptr CreateHloProfilePrinter( }; return MakeUnique( - computation_infos, hlo_profile_index_map.computation_count(), deleter); + computation_infos, hlo_profile_index_map.computation_count(), + /*profile_counters_size=*/max_profile_index, deleter); } HloExecutionProfile::HloExecutionProfile( diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index 470fd4ce3c205d84152238f4b18daad77e403f68..1a6b069609cb58bcc9659b4457453758a277bc0e 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -125,6 +125,9 @@ class HloExecutionProfile { } std::vector* mutable_profile_counters() { return &profile_counters_; } + const std::vector& profile_counters() const { + return profile_counters_; + } private: const HloProfilePrinter& hlo_profile_printer_; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 84187d578346eafd5e32727a15f5eab9cc79feef..f7c6435002d278d93cc0814041a7e055e5573e3e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -508,8 +509,17 @@ stylesheet=" // The "to_node" value may be a NULL, indicating that this points to the // "root" tag rather than a normal node. - int64 from_node_id = node_ids_.at(from_node); - int64 to_node_id = to_node ? node_ids_.at(to_node) : root_node_id_; + int64 from_node_id = + tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1); + if (from_node_id == -1) { + LOG(FATAL) << from_node->name() << " was added to edges but not to nodes"; + } + int64 to_node_id = + to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1) + : root_node_id_; + if (to_node != nullptr && to_node_id == -1) { + LOG(FATAL) << to_node->name() << " was added to edges but not to nodes"; + } add_hover_css_rule("node", from_node_id, kBlue); add_hover_css_rule("node", to_node_id, kRed); @@ -653,12 +663,15 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) { string HloDotDumper::DumpRootTag() { const HloInstruction* from = GetNodeForEdge(computation_->root_instruction()); - auto from_id = InstructionId(from); - if (!filter_.Show(from)) { + // We didn't display constants as separate nodes; so if the root is a + // constant, we don't add root tag or edge for it. + if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) { return ""; } + auto from_id = InstructionId(from); + // The ID of the root computation is otherwise unused, so it makes a good ID // to use for the root-tag node. However, the edge_ids_ map requires a // HloInstruction* pointer for the 'to' value, so we use a NULL value there @@ -784,7 +797,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; - if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { + if (tensorflow::StringPiece(constant->name()).starts_with("constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); @@ -948,6 +961,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: + case HloOpcode::kFft: return kDarkBlue; case HloOpcode::kReducePrecision: return kRed; @@ -1000,7 +1014,7 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. if (tensorflow::StringPiece(instr->name()) - .starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) { + .starts_with(HloOpcodeString(instr->opcode()))) { return Printf("%s", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = @@ -1036,50 +1050,15 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { } string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { - string opcode_specific_info = [&]() -> string { - switch (instr->opcode()) { - case HloOpcode::kRng: - return RandomDistribution_Name(instr->random_distribution()); - case HloOpcode::kConvolution: - return StrCat( - HtmlLikeStringSanitize( - instr->ConvolutionDimensionNumbersToString()), - "
", - HtmlLikeStringSanitize(window_util::ToString(instr->window()))); - case HloOpcode::kBroadcast: - case HloOpcode::kTranspose: - case HloOpcode::kReduce: - return Printf("dims={%s}", Join(instr->dimensions(), ",")); - case HloOpcode::kGetTupleElement: - return Printf("index=%lld", instr->tuple_index()); - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormGrad: - return Printf("feature_index=%lld", instr->feature_index()); - case HloOpcode::kCustomCall: - return Printf("custom_call_target=%s", instr->custom_call_target()); - case HloOpcode::kSlice: - return std::all_of(instr->slice_strides().begin(), - instr->slice_strides().end(), - [](int64 stride) { return stride == 1; }) - ? "" - : StrCat("stride=", VectorString(instr->slice_strides())); - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - return StrCat("channel_id=", instr->channel_id()); - default: - return ""; - } - }(); - std::vector lines; - if (!opcode_specific_info.empty()) { - lines.push_back(opcode_specific_info); - } - if (instr->has_sharding()) { - lines.push_back(StrCat("sharding=", instr->sharding().ToString())); + + // Get the instruction's extra attributes excluding the names of its + // subcomputations, since those are drawn explicitly in the graph. + for (const auto& line : instr->ExtraAttributesToString( + HloPrintOptions().set_print_subcomputation_references(false))) { + lines.push_back(HtmlLikeStringSanitize(line)); } + // Show the shape and layout of the instruction, unless it's an inlined fusion // node -- there the shape and layout is present in the output node. if (instr->opcode() != HloOpcode::kFusion || @@ -1091,7 +1070,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { instr->shape().dimensions_size() > 1 && !ShapeUtil::IsTuple(instr->shape())) { StrAppend(&instr_shape, "{", - Join(instr->shape().layout().minor_to_major(), ","), "}"); + Join(LayoutUtil::MinorToMajor(instr->shape()), ","), "}"); } // Some instructions have giant tuples as their shapes, so truncate the @@ -1353,19 +1332,16 @@ string SaveGraph(const string& graph, file_extension = ".pbtxt"; break; } - string path = JoinPath( - dest_path, StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); + string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, ".")); auto status = Status::OK(); - int fd = mkstemps(&path[0], file_extension.length()); - if (fd < 0) { + auto env = tensorflow::Env::Default(); + if (!env->CreateUniqueFileName(&path, file_extension)) { status = Status(tensorflow::error::Code::UNKNOWN, StrCat("Failed to create temporary file to dump HLO graph: ", strerror(errno))); } else { - status = - tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph); - close(fd); + status = tensorflow::WriteStringToFile(env, path, graph); } if (!status.ok()) { LOG(WARNING) << "Saving HLO graph failed: " << status; @@ -1438,7 +1414,8 @@ void DumpText(const HloModule& module, const string& label, do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); string path = JoinPath(directory_path, filename); TF_CHECK_OK(WriteStringToFile( - env, path, module.ToString(/*include_large_constants=*/true))); + env, path, + module.ToString(HloPrintOptions().set_print_large_constants(true)))); LOG(INFO) << "dumping module '" << module.name() << "' to " << path; } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 8e1531c87f9c6e133e2d6763b046b1d5dcbcd09f..1f00aa41dc783f9e5657f5fa654884a31fae0fe7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -117,5 +117,18 @@ TEST(HloGraphDumperTest, NestedFusion) { HasSubstr(inner_sum->name())); } +TEST(HloGraphDumperTest, Constant) { + HloComputation::Builder b("b"); + auto instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-42))); + instruction->set_name("i_am_a_constant_root_instruction"); + HloModule m(TestName()); + HloComputation* root_computation = m.AddEntryComputation(b.Build()); + string graph = hlo_graph_dumper::DumpGraph( + *root_computation, /*label=*/"an_empty_graph", DebugOptions()); + EXPECT_THAT(graph, HasSubstr("an_empty_graph")); + EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); +} + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c30c4326547bbeae4f7054974f0d3fade65e3382..90121f7ffe11b379bea9e83a483c7e752c97998c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -101,10 +101,10 @@ StatusOr> HloInstruction::CreateFromProto( instruction->metadata_ = proto.metadata(); if (proto.has_literal()) { - instruction->literal_ = MakeUnique(proto.literal()); + TF_ASSIGN_OR_RETURN(instruction->literal_, + Literal::CreateFromProto(proto.literal())); } instruction->parameter_number_ = proto.parameter_number(); - instruction->parameter_name_ = proto.parameter_name(); instruction->tuple_index_ = proto.tuple_index(); for (int64 dimension : proto.dimensions()) { @@ -118,6 +118,10 @@ StatusOr> HloInstruction::CreateFromProto( MakeUnique( proto.convolution_dimension_numbers()); } + if (proto.has_dot_dimension_numbers()) { + instruction->dot_dimension_numbers_ = + MakeUnique(proto.dot_dimension_numbers()); + } for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { instruction->slice_starts_.push_back(slice_dimensions.start()); @@ -141,6 +145,10 @@ StatusOr> HloInstruction::CreateFromProto( instruction->infeed_config_ = proto.infeed_config(); instruction->custom_call_target_ = proto.custom_call_target(); instruction->outfeed_shape_ = proto.outfeed_shape(); + instruction->fft_type_ = proto.fft_type(); + for (int64 fft_len : proto.fft_length()) { + instruction->fft_length_.push_back(fft_len); + } return std::move(instruction); } @@ -150,7 +158,6 @@ StatusOr> HloInstruction::CreateFromProto( auto instruction = WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); instruction->parameter_number_ = parameter_number; - instruction->parameter_name_ = name; instruction->name_ = name; return instruction; } @@ -160,8 +167,7 @@ StatusOr> HloInstruction::CreateFromProto( auto instruction = WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); - instruction->literal_.reset(new Literal); - instruction->literal_->append_u8s(tag); + instruction->literal_ = Literal::CreateR1U8(tag); return instruction; } @@ -332,6 +338,41 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateFft( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape)); + instruction->AppendOperand(operand); + instruction->fft_type_ = fft_type; + instruction->fft_length_.assign(fft_length.begin(), fft_length.end()); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = + MakeUnique(dimension_numbers); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateCanonicalDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = MakeUnique(); + instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); + instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, @@ -346,12 +387,9 @@ HloInstruction::CreateReducePrecision(const Shape& shape, } /* static */ std::unique_ptr -HloInstruction::CreateCrossReplicaSum(const Shape& shape, - HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); - instruction->AppendOperand(operand); - return instruction; +HloInstruction::CreateCrossReplicaSum( + const Shape& shape, tensorflow::gtl::ArraySlice operands) { + return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -670,10 +708,26 @@ HloInstruction::CreateSelectAndScatter( return instruction; } +// We put the fusion kind into the instruction's name for transpose-dot and +// backward-conv fusions, since those fusions are really just describing a type +// of dot/conv rather than generating a novel computation. +static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { + switch (fusion_kind) { + case HloInstruction::FusionKind::kTransposeDot: + return "dot_fusion"; + case HloInstruction::FusionKind::kConvBackwardInput: + case HloInstruction::FusionKind::kConvBackwardFilter: + return "conv_fusion"; + default: + return "fusion"; + } +} + /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; + instruction->name_ = FusionNodeName(fusion_kind); instruction->set_parent(fused_root->parent()); instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); @@ -689,6 +743,7 @@ HloInstruction::CreateSelectAndScatter( instruction->AppendOperand(operand); } instruction->fusion_kind_ = fusion_kind; + instruction->name_ = FusionNodeName(fusion_kind); instruction->called_computations_.push_back(fusion_computation); fusion_computation->SetFusionInstruction(instruction.get()); return instruction; @@ -985,6 +1040,7 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kSendDone: case HloOpcode::kRecv: case HloOpcode::kRecvDone: + case HloOpcode::kRng: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: @@ -1086,7 +1142,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kLe: case HloOpcode::kLt: case HloOpcode::kNe: - case HloOpcode::kDot: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1138,9 +1193,16 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, *convolution_dimension_numbers_); break; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kDot: + CHECK_EQ(new_operands.size(), 2); + clone = CreateDot(shape, new_operands[0], new_operands[1], + *dot_dimension_numbers_); + break; + case HloOpcode::kFft: CHECK_EQ(new_operands.size(), 1); - clone = CreateCrossReplicaSum(shape, new_operands[0]); + return CreateFft(shape, new_operands[0], fft_type_, fft_length_); + case HloOpcode::kCrossReplicaSum: + clone = CreateCrossReplicaSum(shape, new_operands); break; case HloOpcode::kGetTupleElement: CHECK_EQ(new_operands.size(), 1); @@ -1215,7 +1277,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CloneFusionWithNewOperands(shape, new_operands, module); break; case HloOpcode::kParameter: - clone = CreateParameter(parameter_number_, shape, parameter_name_); + clone = CreateParameter(parameter_number_, shape, name_); break; case HloOpcode::kBatchNormTraining: CHECK_EQ(new_operands.size(), 3); @@ -1244,10 +1306,27 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[4], epsilon(), feature_index()); break; case HloOpcode::kConditional: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: + CHECK_EQ(new_operands.size(), 3); + clone = CreateConditional(shape, new_operands[0], new_operands[1], + true_computation(), new_operands[2], + false_computation()); + break; case HloOpcode::kSend: + CHECK_EQ(new_operands.size(), 1); + clone = CreateSend(new_operands[0], channel_id()); + break; case HloOpcode::kSendDone: + CHECK_EQ(new_operands.size(), 1); + clone = CreateSendDone(new_operands[0]); + break; + case HloOpcode::kRecv: + CHECK_EQ(new_operands.size(), 0); + clone = CreateRecv(shape, channel_id()); + break; + case HloOpcode::kRecvDone: + CHECK_EQ(new_operands.size(), 1); + clone = CreateRecvDone(new_operands[0]); + break; case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -1492,7 +1571,7 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, - std::function + const std::function& eq_computations) const { // Perform opcode specific checks. switch (opcode()) { @@ -1509,7 +1588,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kFloor: @@ -1582,6 +1660,15 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals( convolution_dimension_numbers(), other.convolution_dimension_numbers()); + // Check dot dimension numbers. + case HloOpcode::kDot: + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + other.dot_dimension_numbers()); + + // FFT has various types & lengths. + case HloOpcode::kFft: + return fft_type() == other.fft_type() && + fft_length() == other.fft_length(); // Reduction results are determined by the reduction dimension and the // reduction computation. @@ -1636,9 +1723,11 @@ bool HloInstruction::IdenticalSlowPath( return custom_call_target_ == other.custom_call_target_; case HloOpcode::kReverse: return dimensions() == other.dimensions(); + case HloOpcode::kConditional: + return eq_computations(true_computation(), other.true_computation()) && + eq_computations(false_computation(), other.false_computation()); // These opcodes are not yet supported. - case HloOpcode::kConditional: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: @@ -1882,16 +1971,23 @@ string HloInstruction::SignatureString() const { return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } -string HloInstruction::ToString(bool compact_operands, bool include_metadata, - bool include_large_constants) const { +namespace { + +string PrintName(const string& name, const HloPrintOptions& options) { + return StrCat(options.print_percent() ? "%" : "", name); +} + +} // namespace + +string HloInstruction::ToString(const HloPrintOptions& options) const { string result = - StrCat("%", name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", - HloOpcodeString(opcode()), "(", - OperandsToString(compact_operands, include_large_constants), ")"); - for (const string& extra : ExtraAttributesToString()) { + StrCat(PrintName(name(), options), " = ", + ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", OperandsToString(options), ")"); + for (const string& extra : ExtraAttributesToString(options)) { StrAppend(&result, ", ", extra); } - if (include_metadata && + if (options.print_metadata() && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); @@ -1899,14 +1995,13 @@ string HloInstruction::ToString(bool compact_operands, bool include_metadata, return result; } -string HloInstruction::OperandsToString(bool compact, - bool include_large_constants) const { +string HloInstruction::OperandsToString(const HloPrintOptions& options) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || - include_large_constants) { + options.print_large_constants()) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); @@ -1931,14 +2026,19 @@ string HloInstruction::OperandsToString(bool compact, } else { tensorflow::gtl::ArraySlice slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; - if (compact && slice.size() > kMaxOperandsToShowIfCompact) { + if (options.compact_operands() && + slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { - *out += ShapeUtil::HumanStringWithLayout(operand->shape()); - if (!compact) { - StrAppend(out, " %", operand->name()); + std::vector str; + if (options.print_operand_shape()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); + } + if (!options.compact_operands()) { + str.push_back(PrintName(operand->name(), options)); } + StrAppend(out, Join(str, " ")); }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { @@ -1948,7 +2048,8 @@ string HloInstruction::OperandsToString(bool compact, return operands; } -std::vector HloInstruction::ExtraAttributesToString() const { +std::vector HloInstruction::ExtraAttributesToString( + const HloPrintOptions& options) const { std::vector extra; if (opcode() == HloOpcode::kFusion) { extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); @@ -1990,23 +2091,42 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (convolution_dimension_numbers_ != nullptr) { extra.push_back(ConvolutionDimensionNumbersToString()); } - - if (opcode() == HloOpcode::kWhile) { - extra.push_back(StrCat("condition=%", while_condition()->name())); - extra.push_back(StrCat("body=%", while_body()->name())); - } else if (opcode() == HloOpcode::kSelectAndScatter) { - extra.push_back(StrCat("select=%", select()->name())); - extra.push_back(StrCat("scatter=%", scatter()->name())); - } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || - opcode() == HloOpcode::kReduceWindow || - opcode() == HloOpcode::kReduce) { - extra.push_back(StrCat("to_apply=%", to_apply()->name())); - } else if (!called_computations().empty()) { - extra.push_back(StrCat( - "calls=", Join(called_computations(), ", ", - [](string* out, const HloComputation* computation) { - StrAppend(out, "%", computation->name()); - }))); + if (dot_dimension_numbers_ != nullptr) { + extra.push_back(DotDimensionNumbersToString()); + } + if (opcode() == HloOpcode::kFft) { + extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); + extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); + } + + if (options.print_subcomputation_references()) { + if (opcode() == HloOpcode::kWhile) { + extra.push_back( + StrCat("condition=", PrintName(while_condition()->name(), options))); + extra.push_back( + StrCat("body=", PrintName(while_body()->name(), options))); + } else if (opcode() == HloOpcode::kSelectAndScatter) { + extra.push_back(StrCat("select=", PrintName(select()->name(), options))); + extra.push_back( + StrCat("scatter=", PrintName(scatter()->name(), options))); + } else if (opcode() == HloOpcode::kConditional) { + extra.push_back(StrCat("true_computation=", + PrintName(true_computation()->name(), options))); + extra.push_back(StrCat("false_computation=", + PrintName(false_computation()->name(), options))); + } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || + opcode() == HloOpcode::kReduceWindow || + opcode() == HloOpcode::kReduce) { + extra.push_back( + StrCat("to_apply=", PrintName(to_apply()->name(), options))); + } else if (!called_computations().empty()) { + extra.push_back(StrCat( + "calls=", Join(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, + PrintName(computation->name(), options)); + }))); + } } if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || @@ -2023,8 +2143,9 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (!control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", Join(control_predecessors_, ", ", - [](string* out, HloInstruction* pre) { - StrAppend(out, "%", pre->name()); + [&](string* out, HloInstruction* pre) { + StrAppend(out, + PrintName(pre->name(), options)); }), "}")); } @@ -2035,6 +2156,22 @@ std::vector HloInstruction::ExtraAttributesToString() const { extra.push_back( StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); } + if (opcode() == HloOpcode::kRng) { + extra.push_back( + StrCat("distribution=", RandomDistributionToString(distribution_))); + } + if (opcode() == HloOpcode::kReducePrecision) { + extra.push_back(StrCat("exponent_bits=", exponent_bits_)); + extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); + } + + // By contract, we print the custom call target even if + // !options.print_subcomputation_references(), because the call target is not + // an HloComputation. + if (opcode() == HloOpcode::kCustomCall) { + extra.push_back( + StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + } return extra; } @@ -2064,7 +2201,6 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_literal() = literal_->ToProto(); } proto.set_parameter_number(parameter_number_); - proto.set_parameter_name(parameter_name_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); *proto.mutable_fused_instructions_computation() = @@ -2086,6 +2222,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = *convolution_dimension_numbers_; } + if (dot_dimension_numbers_ != nullptr) { + *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -2110,6 +2249,10 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_infeed_config(infeed_config_); proto.set_custom_call_target(custom_call_target_); *proto.mutable_outfeed_shape() = outfeed_shape_; + proto.set_fft_type(fft_type_); + for (int64 fft_len : fft_length_) { + proto.add_fft_length(fft_len); + } return proto; } @@ -2120,7 +2263,7 @@ string HloInstruction::ToCategory() const { return "data formatting"; } - if (opcode() == HloOpcode::kConvolution) { + auto conv_category = [&] { string category = "convolution"; if (window_util::HasBaseDilation(window())) { category += " base-dilated"; @@ -2129,44 +2272,36 @@ string HloInstruction::ToCategory() const { category += " window-dilated"; } return category; + }; + + if (opcode() == HloOpcode::kConvolution) { + return conv_category(); } + // Give transpose-dot and backwards-conv fusions the categories "dot" and + // "convolution" so they match the categories of proper kDot and kConvolution + // ops. These fusion categories are really just a way of expressing a + // particular kind of dot or conv, so they should have the same category as a + // vanilla dot/conv. if (opcode() == HloOpcode::kFusion) { - if (operands().size() == 2) { - bool saw_rank_1 = false; - bool saw_higher_rank = false; - for (const auto* operand : operands()) { - if (!ShapeUtil::IsTuple(operand->shape())) { - saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; - saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; - } - } - if (saw_rank_1 && saw_higher_rank) { - return "rank-1-broadcast binary fusion"; - } - } switch (fusion_kind()) { case FusionKind::kLoop: - if (IsElementwise()) { - return "elementwise fusion"; - } else { - return "non-elementwise fusion"; - } + return "loop fusion"; case FusionKind::kInput: return "input fusion"; case FusionKind::kOutput: return "output fusion"; case FusionKind::kTransposeDot: - return "dot fusion"; + return "dot"; case FusionKind::kConvBackwardFilter: case FusionKind::kConvBackwardInput: - return "convolution fusion"; + return conv_category(); case FusionKind::kCustom: return "custom fusion"; } } - if (IsElementwise() && opcode() != HloOpcode::kFusion) { + if (IsElementwise()) { return "non-fusion elementwise"; } @@ -2182,7 +2317,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { string HloInstruction::TracingTag() const { CHECK_EQ(HloOpcode::kTrace, opcode()); CHECK(literal_ != nullptr); - return literal_->u8s_string(); + return literal_->GetR1U8AsString(); } bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } @@ -2325,6 +2460,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSelect(this); case HloOpcode::kConvolution: return visitor->HandleConvolution(this); + case HloOpcode::kFft: + return visitor->HandleFft(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); case HloOpcode::kTuple: @@ -3001,6 +3138,28 @@ string OpMetadataToString(const OpMetadata& metadata) { return Join(result, " "); } +string RandomDistributionToString(const RandomDistribution& distribution) { + return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); +} + +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -3051,6 +3210,29 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { return result; } +string HloInstruction::DotDimensionNumbersToString() const { + std::vector result; + if (dot_dimension_numbers_ == nullptr) { + return ""; + } + const DotDimensionNumbers& dnums = *dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result.push_back(StrCat("lhs_batch_dims={", + Join(dnums.lhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("lhs_contracting_dims={", + Join(dnums.lhs_contracting_dimensions(), ","), "}")); + + if (!dnums.rhs_batch_dimensions().empty()) { + result.push_back(StrCat("rhs_batch_dims={", + Join(dnums.rhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("rhs_contracting_dims={", + Join(dnums.rhs_contracting_dimensions(), ","), "}")); + + return Join(result, ", "); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index cda8b07c61e2b36a83184648f6f3744deeb86812..e700ec1d2903ac0bb77e36097c3e1e582206e4d5 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -25,6 +25,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -56,6 +57,107 @@ namespace xla { class HloComputation; class HloModule; +// A bunch of switches that control how the hlo text should be printed. +class HloPrintOptions { + public: + // Constructs the default print options: don't print large constants, don't + // compact operands, no indentation. + HloPrintOptions() + : print_large_constants_(false), + print_subcomputation_references_(true), + print_metadata_(true), + compact_operands_(false), + print_operand_shape_(true), + print_program_shape_(true), + print_percent_(true), + indent_amount_(0) {} + + static HloPrintOptions ShortParsable() { + return HloPrintOptions() + .set_print_large_constants(true) + .set_print_subcomputation_references(true) + .set_print_metadata(false) + .set_print_operand_shape(false) + .set_print_program_shape(false) + .set_print_percent(false); + } + + // If true, large constants will be printed out. + HloPrintOptions& set_print_large_constants(bool value) { + print_large_constants_ = value; + return *this; + } + + // If true, the names of subcomputations (e.g. a fusion node's fused + // computation) won't be printed. This makes the resulting text not parsable. + // + // A CustomCall's call target is printed even if + // print_subcomputation_references is false, because the call target isn't an + // HloComputation. + HloPrintOptions& set_print_subcomputation_references(bool value) { + print_subcomputation_references_ = value; + return *this; + } + + // If true, metatdata will be printed. + HloPrintOptions& set_print_metadata(bool value) { + print_metadata_ = value; + return *this; + } + + // If true, operands' shapes will be printed. + HloPrintOptions& set_print_operand_shape(bool value) { + print_operand_shape_ = value; + return *this; + } + + // If true, program shape of hlo computations will be printed. + HloPrintOptions& set_print_program_shape(bool value) { + print_program_shape_ = value; + return *this; + } + + // If true, names will be printed with prefix '%'. + HloPrintOptions& set_print_percent(bool value) { + print_percent_ = value; + return *this; + } + + // If true, only a part of operands will be printed out, and their names will + // be omitted (note that in this case the text will not be parsable). + HloPrintOptions& set_compact_operands(bool value) { + compact_operands_ = value; + return *this; + } + + // The indent of the hlo text block. + HloPrintOptions& set_indent_amount(int value) { + indent_amount_ = value; + return *this; + } + + bool print_large_constants() const { return print_large_constants_; } + bool print_subcomputation_references() const { + return print_subcomputation_references_; + } + bool print_metadata() const { return print_metadata_; } + bool compact_operands() const { return compact_operands_; } + bool print_operand_shape() const { return print_operand_shape_; } + bool print_program_shape() const { return print_program_shape_; } + bool print_percent() const { return print_percent_; } + int indent_amount() const { return indent_amount_; } + + private: + bool print_large_constants_; + bool print_subcomputation_references_; + bool print_metadata_; + bool compact_operands_; + bool print_operand_shape_; + bool print_program_shape_; + bool print_percent_; + int indent_amount_; +}; + // HLO instructions are the IR used by the high-level compiler. class HloInstruction { public: @@ -160,6 +262,23 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates an FFT op, of the type indicated by fft_type. + static std::unique_ptr CreateFft( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + static std::unique_ptr CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers); + + // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 + // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS + // and the RHS must be of rank 2. + static std::unique_ptr CreateCanonicalDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs); + // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to // reduce it to. @@ -169,7 +288,8 @@ class HloInstruction { // Creates a cross replica sum op. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, HloInstruction* operand); + const Shape& shape, + tensorflow::gtl::ArraySlice operands); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -421,7 +541,7 @@ class HloInstruction { Status RemoveControlDependencyTo(HloInstruction* instruction); // Returns the set of control predecessors (successors) of this - // instruction. Control predecessors (sucessors) must execute before (after) + // instruction. Control predecessors (successors) must execute before (after) // the current instruction. const std::vector& control_predecessors() const { return control_predecessors_; @@ -434,9 +554,9 @@ class HloInstruction { // Layout of the instructions' output array is not considered. bool Identical( const HloInstruction& other, - std::function + const std::function& eq_operands = std::equal_to(), - std::function + const std::function& eq_computations = std::equal_to()) const { // An instruction is always identical to itself. if (this == &other) { @@ -446,11 +566,19 @@ class HloInstruction { // Identical instruction must have the same opcode and identical operands. // In general, there is no need to check shape because shape is inferred // from the shape of the operands. - if (opcode() != other.opcode() || - !ContainersEqual(operands(), other.operands(), - std::move(eq_operands))) { + if (opcode() != other.opcode()) { + return false; + } + if (operands().size() != other.operands().size()) { return false; } + // Use an explicit loop rather than ContainerEquals, because copying around + // std::functions may be too expensive in some cases. + for (size_t i = 0; i < operands().size(); ++i) { + if (!eq_operands(operand(i), other.operand(i))) { + return false; + } + } return IdenticalSlowPath(other, eq_computations); } @@ -540,16 +668,6 @@ class HloInstruction { return parameter_number_; } - const string& parameter_name() const { - CHECK_EQ(HloOpcode::kParameter, opcode_); - return parameter_name_; - } - - void set_parameter_name(const string& str) { - CHECK_EQ(HloOpcode::kParameter, opcode_); - parameter_name_ = str; - } - // Returns the dimension sizes or numbers associated with this instruction. // // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, @@ -637,18 +755,20 @@ class HloInstruction { string SignatureString() const; // Returns a debugging string that represents this instruction. - string ToString(bool compact_operands = false, bool include_metadata = true, - bool include_large_constants = false) const; + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) + string ToString() const { return ToString(HloPrintOptions()); } + string ToString(const HloPrintOptions& options) const; // Components of the ToString() representation: // Returns a string representation of the operand list. - string OperandsToString(bool compact, bool include_large_constants) const; + string OperandsToString(const HloPrintOptions& options) const; // Returns string representation of op-specific attributes. - std::vector ExtraAttributesToString() const; - - string ToStringNoMetadata() const { return ToString(false, false); } + std::vector ExtraAttributesToString( + const HloPrintOptions& options) const; // As ToString, but returns a shorter string. string ToShortString() const; @@ -676,13 +796,15 @@ class HloInstruction { // Returns feature_index field associated with the instruction. The index // represents the index of the feature dimension. // - // Precondition: opcode() == HloOpcode::kBatchNormTraining + // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, + // or kBatchNormGrad. int64 feature_index() const { return feature_index_; } // Returns a epsilon value associated with the instruction. The is a small // number added to the variance to avoid divide-by-zero error. // - // Precondition: opcode() == HloOpcode::kBatchNormTraining + // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, + // or kBatchNormGrad. float epsilon() const { return epsilon_; } // Returns the infeed configuration string. The infeed configuration includes @@ -856,6 +978,17 @@ class HloInstruction { } const std::vector& slice_strides() const { return slice_strides_; } + // Returns the flag that describes whether a slice must be lowered into an + // offset into the original operand. + bool IsInPlaceSlice() const { return is_in_place_slice_; } + + // Sets and returns the flag that describes whether a slice must be lowered + // into an offset into the original operand. + bool SetIsInPlaceSlice(bool value) { + is_in_place_slice_ = value; + return value; + } + // Returns the size of the slice in the given dimension for a dynamic // slice node. // @@ -912,9 +1045,28 @@ class HloInstruction { return *convolution_dimension_numbers_; } + FftType fft_type() const { + CHECK_EQ(HloOpcode::kFft, opcode_); + return fft_type_; + } + + const std::vector& fft_length() const { + CHECK_EQ(HloOpcode::kFft, opcode_); + return fft_length_; + } + // Returns the dump string of the convolution dimension numbers. string ConvolutionDimensionNumbersToString() const; + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + CHECK(dot_dimension_numbers_ != nullptr); + return *dot_dimension_numbers_; + } + + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -1006,10 +1158,9 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Returns a string identifier for this instruction. If no string identifier - // has been explicitly set, then the identifier is the serialized pointer to - // this instruction. + // Gets/sets the string identifier for this instruction. const string& name() const { return name_; } + void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1070,7 +1221,7 @@ class HloInstruction { // See comments on Identical(). bool IdenticalSlowPath( const HloInstruction& other, - std::function + const std::function& eq_computations) const; // Creates an n-ary elementwise operation. @@ -1173,11 +1324,23 @@ class HloInstruction { // Describes the dimension numbers used for a convolution. std::unique_ptr convolution_dimension_numbers_; + // Describes the dimension numbers used for a dot. + std::unique_ptr dot_dimension_numbers_; + + // Describes FFT type for an FFT instruction. + FftType fft_type_ = FftType::FFT; + + // Indicates the FFT length for an FFT instruction. + std::vector fft_length_; + // Describes the [begin, end) index range for a slice. std::vector slice_starts_; std::vector slice_limits_; std::vector slice_strides_; + // Describes whether the slice can be lowered to an offset into the operand. + bool is_in_place_slice_ = false; + // The bit sizes for a reduce-precision operation. int32 exponent_bits_ = 0; int32 mantissa_bits_ = 0; @@ -1198,7 +1361,6 @@ class HloInstruction { // For parameter instructions this field holds the parameter number. int64 parameter_number_ = 0; - string parameter_name_; // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; @@ -1267,9 +1429,12 @@ string ToString(HloInstruction::FusionKind kind); StatusOr StringToFusionKind( const string& kind_name); -// Custom stringification functions for protos that live inside HloInstruction. +// Custom (de)stringification functions for protos that live inside +// HloInstruction. string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); +string RandomDistributionToString(const RandomDistribution& distribution); +StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); @@ -1295,6 +1460,10 @@ template using ConstHloInstructionMap = std::map; +using HloInstructionSet = std::set; +using ConstHloInstructionSet = + std::set; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 76b12fc8d3aadc0a874ce059851666fbcd6a4e94..3af3b29cedd06996dd4a175fdb1584c705ceea87 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1068,8 +1068,11 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); @@ -1088,48 +1091,6 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { root2->operand(1)->operand(0)->shape())); } -TEST_F(HloInstructionTest, IsRandomFusable) { - auto shape = ShapeUtil::MakeShape(F32, {2, 2}); - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - EXPECT_EQ(HloOpcode::kFusion, root->opcode()); - } - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - builder.AddInstruction(HloInstruction::CreateUnary( - shape, HloOpcode::kNegate, rng)); - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - EXPECT_EQ(HloOpcode::kFusion, root->operand(0)->opcode()); - } -} - - TEST_F(HloInstructionTest, CloneSuffixNames) { // Test that the suffix string added to cloned instructions is not // duplicated. Rather a numeric incrementing value should be appended. That @@ -1169,7 +1130,7 @@ TEST_F(HloInstructionTest, CloneSuffixNames) { } TEST_F(HloInstructionTest, Stringification) { - // Tests stringification of a simple op, fusion, and while. + // Tests stringification of a simple op, fusion, while, and conditional. const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); @@ -1182,12 +1143,17 @@ TEST_F(HloInstructionTest, Stringification) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(dot->ToString(false, false), + EXPECT_EQ(dot->ToString(options), "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " - "%transpose)"); + "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); @@ -1195,15 +1161,25 @@ TEST_F(HloInstructionTest, Stringification) { {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); EXPECT_EQ( - fusion->ToString(false, false), - "%fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " + fusion->ToString(options), + "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); - EXPECT_EQ(loop->ToString(false, false), + EXPECT_EQ(loop->ToString(options), "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " "condition=%TransposeDot, body=%TransposeDot"); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + sout, pred, x, computation, x, computation)); + EXPECT_EQ(conditional->ToString(options), + "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, " + "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), " + "true_computation=%TransposeDot, false_computation=%TransposeDot"); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index faaf73ea1ce5c77b0522cb3276b4efd78aabde16..58bb94221149c9a8b550add900dff52a53565985 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -35,14 +35,15 @@ namespace xla { HloModule::HloModule(const string& name, const VersionedComputationHandle& entry_computation_handle, const HloModuleConfig& config) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), config_(config), has_entry_computation_handle_(true), entry_computation_handle_(entry_computation_handle) {} -HloModule::HloModule(const string& name) : name_(name) {} +HloModule::HloModule(const string& name) + : name_(NameUniquer::GetSanitizedName(name)) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) - : name_(name), config_(config) {} + : name_(NameUniquer::GetSanitizedName(name)), config_(config) {} HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -170,17 +171,14 @@ void HloModule::ReplaceComputations( computations_ = std::move(new_computations); } -string HloModule::ToString(bool include_large_constants) const { +string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << ":\n\n"; + s << "HloModule " << name() << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString( - /*nested_level=*/0, - /*include_large_constants=*/include_large_constants) - << "\n\n"; + s << computation->ToString(options) << "\n\n"; } return s.str(); } @@ -232,8 +230,8 @@ StatusOr ProgramShapeFromProto(const HloModuleProto& module) { << "Entry computation has more than one parameter instruction " "with parameter number " << instruction.parameter_number(); - parameters[instruction.parameter_number()] = { - instruction.parameter_name(), &instruction.shape()}; + parameters[instruction.parameter_number()] = {instruction.name(), + &instruction.shape()}; } } TF_RET_CHECK(root != nullptr) @@ -459,6 +457,14 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( return call; } +int64 HloModule::instruction_count() const { + int64 n = 0; + for (const auto& computation : computations_) { + n += computation->instruction_count(); + } + return n; +} + std::list HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 5141e7bc8d4cf0ef4cd83310772e0c5d66b5da12..e377654d024819d00f73f43a70d363bd902dc981 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -98,6 +98,10 @@ class HloModule { return config_.mutable_entry_computation_layout(); } + ComputationLayout entry_computation_layout() const { + return config_.entry_computation_layout(); + } + const VersionedComputationHandle& entry_computation_handle() const { return entry_computation_handle_; } @@ -125,6 +129,9 @@ class HloModule { // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } + // Gets the number of instructions in this module. + int64 instruction_count() const; + // Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. @@ -143,7 +150,12 @@ class HloModule { const HloModuleConfig& config() const { return config_; } - string ToString(bool include_large_constants = false) const; + // Return a string representation of the module. + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) + string ToString() const { return ToString(HloPrintOptions()); } + string ToString(const HloPrintOptions& options) const; // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index bf6440d66cac0d3a929c377202b212aba262f887..0f5d3dccb74e6e3c88e51685392171f940c03596 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -135,14 +135,15 @@ TEST_F(HloModuleTest, LargeConstantToString) { module->AddEntryComputation(builder.Build()); EXPECT_EQ( - "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n " "ROOT %constant = f32[16]{0} constant({...})\n}\n\n", - module->ToString(/*include_large_constants=*/false)); + module->ToString(HloPrintOptions().set_print_large_constants(false))); + EXPECT_EQ( - "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n " "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, " "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n", - module->ToString(/*include_large_constants=*/true)); + module->ToString(HloPrintOptions().set_print_large_constants(true))); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index f3f79357582ac7661a532e94031acdbca0b86784..3d64523a79fc50638fdf378b5d521a5cd4482b90 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -73,6 +73,7 @@ namespace xla { V(kDynamicUpdateSlice, "dynamic-update-slice") \ V(kEq, "equal-to", kHloOpcodeIsComparison) \ V(kExp, "exponential") \ + V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h index 316753a82ab2a9b5459b71c723a8e817ee2cacbf..2f056490ae027872570f7a0821ee63114f49fab8 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.h +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h @@ -65,9 +65,11 @@ class HloProfilePrinter { HloProfilePrinter( HloComputationInfo* computation_infos, int64 computation_infos_size, + int64 profile_counters_size, std::function deleter = nullptr) : computation_infos_(computation_infos), computation_infos_size_(computation_infos_size), + profile_counters_size_(profile_counters_size), deleter_(std::move(deleter)) {} HloProfilePrinter(HloProfilePrinter&& other) { @@ -79,10 +81,13 @@ class HloProfilePrinter { HloProfilePrinter(const HloProfilePrinter&) = delete; HloProfilePrinter& operator=(const HloProfilePrinter&) = delete; - // Convert the profile counter sequence `counters` to a human readable string + // Converts the profile counter sequence `counters` to a human readable string // representation. string ToString(const int64* counters, double clock_rate_ghz) const; + // Returns the size of the profile buffer expected by this printer. + int64 profile_counters_size() const { return profile_counters_size_; } + ~HloProfilePrinter(); private: @@ -90,6 +95,7 @@ class HloProfilePrinter { // is manifested as the deleter_ function. HloComputationInfo* computation_infos_ = nullptr; int64 computation_infos_size_ = 0; + int64 profile_counters_size_ = 0; std::function deleter_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 727ad0178c6227cd2e64c31a4618e781671b9393..78e6a101c10a1e812e3e2631d520139fd0bc425c 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -19,15 +19,20 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { - HloModuleProto proto_module = module.ToProto(); HloOrderingProto proto_ordering = assignment.liveness().hlo_ordering().ToProto(); BufferAssignmentProto proto_assignment = assignment.ToProto(); - HloProto proto; - proto.mutable_hlo_module()->Swap(&proto_module); + HloProto proto = MakeHloProto(module); proto.mutable_hlo_ordering()->Swap(&proto_ordering); proto.mutable_buffer_assignment()->Swap(&proto_assignment); return proto; } +HloProto MakeHloProto(const HloModule& module) { + HloModuleProto proto_module = module.ToProto(); + HloProto proto; + proto.mutable_hlo_module()->Swap(&proto_module); + return proto; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 603259a11fcdca59f58653d9a7a164c983711a57..320288fdb9aa0810b306b1d78bd1ff4cfc366ed2 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -31,6 +31,10 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment); +// Returns a serialized representation of the HLO state, but buffer assignment +// will not be included in the output. +HloProto MakeHloProto(const HloModule& module); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index d7bdac9c86579f19afbba133772c2c50894853d1..553ec11f6f9a2997ab7113f9b8241e04c7fe20d5 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -30,11 +30,17 @@ namespace xla { class HloInstruction; -// A class for computing and representing reachability between HloInstructions. +// A class for representing reachability between HloInstructions. +// +// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix +// and it is up to the user of the class to set the adjacency matrix such that +// it represents reachability, i.e. such that it is transitive. That the graph +// be transitive is thus not an invariant of this class, but it is required for +// the name of the class and its methods to make sense. class HloReachabilityMap { public: - // Sets up an empty reachable matrix for the full set of instructions - // specified in 'instructions'. + // Sets up a graph with no edges and where the nodes correspond to the given + // instructions. explicit HloReachabilityMap(const std::list& instructions); // Set the reachability set of 'instruction' to the union of the reachability @@ -42,17 +48,33 @@ class HloReachabilityMap { // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from // itself. Returns whether the reachability set of 'instruction' changed. + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // vector in the internal graph of this HloReachabilityMap for the given + // instruction and does not transitively update any other part of the + // adjacency matrix. bool SetReachabilityToUnion( tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); // Sets entry so that IsReachable(a, b) will return true + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // matrix in the internal graph of this HloReachabilityMap to have an edge + // from a to b and does not transitively update any other part of the + // adjacency matrix. void SetReachable(const HloInstruction* a, const HloInstruction* b); // Returns true if "b" is reachable from "a" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; private: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 017f996bc4d1902c81f96425b7bc28d52622df0f..c6b4dc0368d92fd477decdfb38045f74f8696803 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -566,7 +566,9 @@ Status MemoryUsageTracker::BeginInstruction(Item* item) { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -603,8 +605,9 @@ Status MemoryUsageTracker::EndInstruction() { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); - + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -1021,7 +1024,9 @@ StatusOr HloRematerialization::RematerializeComputation( HloInstruction* best = best_item->instruction; VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " - << memory_tracker.MemoryReducedIfRematerialized(best_item) << ")"; + << HumanReadableNumBytes( + memory_tracker.MemoryReducedIfRematerialized(best_item)) + << ")"; changed = true; remat_count++; @@ -1101,8 +1106,8 @@ StatusOr HloRematerialization::RematerializeComputation( net_instructions_added++; } - VLOG(3) << "memory_usage after rematerialization = " - << memory_tracker.memory_usage(); + VLOG(1) << "memory_usage after rematerialization = " + << HumanReadableNumBytes(memory_tracker.memory_usage()); } const CallSite* callsite = call_graph_node.GetCallSite(instruction); @@ -1208,11 +1213,12 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, - CreateMemoryMinimizingSequence( - *module, [this](const LogicalBuffer& buffer) { - return size_function_(buffer.shape()); - })); + TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( + *module, + [this](const LogicalBuffer& buffer) { + return size_function_(buffer.shape()); + }, + scheduler_algorithm_)); // Compute peak memory usage of all computations in the module called in a // sequential context. call_graph_ = CallGraph::Build(module); @@ -1313,9 +1319,10 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, + SchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes) { - HloRematerialization remat(size_function); + HloRematerialization remat(scheduler_algorithm, size_function); return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 11f79a6d4158c6251c2faf63e9cac4e742440863..52553439033a3bcfa4b472f13f9cd4b1ecf5ed96 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -20,6 +20,7 @@ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { @@ -65,12 +66,15 @@ class HloRematerialization { // code generation. static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence, + HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm, + SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes = nullptr); protected: - HloRematerialization(const ShapeSizeFunction& size_function) - : size_function_(size_function) {} + HloRematerialization(SchedulerAlgorithm scheduler_algorithm, + const ShapeSizeFunction& size_function) + : scheduler_algorithm_(scheduler_algorithm), + size_function_(size_function) {} ~HloRematerialization() {} // Runs rematerialization on the given module. Returns whether the module was @@ -103,6 +107,9 @@ class HloRematerialization { StatusOr CalledComputationsMemoryUsage( const HloInstruction* instruction) const; + // Selects an algorithm to use for HLO scheduling. + SchedulerAlgorithm scheduler_algorithm_; + // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index d88aa4bb567c6c5f6eab54f12239bf7040339c39..1b7d26dde501a6a0955d62ea0938e0683a32d49d 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -158,11 +158,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/14 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/14 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -191,11 +191,11 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/20 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/20 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -232,11 +232,11 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/17 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/17 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -268,11 +268,11 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(body_computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/15 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/15 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // Both computations should have a rematerialized instruction added. @@ -310,11 +310,11 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/13 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/13 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // All computations should have a rematerialized instruction added. @@ -323,6 +323,76 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { EXPECT_EQ(inner_computation->instruction_count(), 8); } +TEST_F(HloRematerializationTest, RngNotRematerialized) { + // Test that a single rng is not rematerialized: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] rng = rng(param) + // F32[1024] tanh = tanh(rng) + // F32[1024] exp = exp(rng) + // F32[1024] add_0 = add(rng, tanh) // LIVE: add_0 + rng + + // // tanh + exp + // + // F32[1024] add_1 = add(rng, add(exp, add_0)) // LIVE: add_1 + add_0 + + // // rng + tanh + exp + // + // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + + // // rng + tanh + exp + auto module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto rng = builder.AddInstruction(HloInstruction::CreateRng( + vec1024_shape_, RandomDistribution::RNG_UNIFORM, {param, param})); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng)); + auto add_0 = builder.AddInstruction( + HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh)); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, exp, add_0)))); + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, tanh, add_1)))); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto count_rngs = [](const HloComputation* computation) { + int64 rng_count = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kRng) { + ++rng_count; + } + } + return rng_count; + }; + // Before rematerialization there should be a single broadcast rng in + // the graph. + ASSERT_EQ(count_rngs(entry_computation), 1); + const int64 original_instruction_count = + entry_computation->instruction_count(); + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSERT_OK_AND_ASSIGN( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), + module.get(), SchedulerAlgorithm::kAuto, &sequence)); + EXPECT_TRUE(changed); + // The rng should not have been rematerialized. + EXPECT_EQ(count_rngs(entry_computation), 1); + // There should have been rematerialization. + EXPECT_GT(entry_computation->instruction_count(), original_instruction_count); +} + TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Test that a single instruction is rematerialized several times. Module: // @@ -406,11 +476,11 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -503,11 +573,11 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 6b6d48233a7da50927207b8334186ee5105db268..204a8bf748685af71ac82be0d102cf7f76c7b38f 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -39,6 +39,14 @@ namespace se = ::perftools::gputools; namespace xla { +/*static*/ StatusOr> +HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options) { + HloModuleConfig config; + config.set_debug_options(debug_options); + return tools::Parse(hlo_string, config); +} + /*static*/ StatusOr> HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, const DebugOptions& debug_options) { @@ -104,17 +112,12 @@ HloRunner::HloRunner(se::Platform* platform) { VLOG(1) << "Created HloRunner for platform: " << platform->Name(); } -HloRunner::~HloRunner() { - // Deallocate all the memory allocated during the tests. - for (auto& allocation : allocations_) { - backend().default_stream_executor()->Deallocate(&allocation); - } -} +HloRunner::~HloRunner() {} -StatusOr HloRunner::Execute( +StatusOr> HloRunner::ExecuteInternal( std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments, - Shape* result_shape, bool run_hlo_passes) { + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes) { if (run_hlo_passes) { TF_ASSIGN_OR_RETURN( module, backend().compiler()->RunHloPasses( @@ -129,6 +132,7 @@ StatusOr HloRunner::Execute( stream.Init(); ExecutableRunOptions run_options; + run_options.set_device_ordinal(backend().default_device_ordinal()); run_options.set_stream(&stream); run_options.set_allocator(backend().memory_allocator()); run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); @@ -138,73 +142,43 @@ StatusOr HloRunner::Execute( ServiceExecutableRunOptions service_run_options( run_options, backend().StreamBorrower(), backend().inter_op_thread_pool()); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase result, - executable->ExecuteOnStream(&service_run_options, arguments, - /*hlo_execution_profile=*/nullptr)); - TF_RET_CHECK(stream.BlockHostUntilDone()); - - allocations_.push_back(result); - - *result_shape = executable->result_shape(); - if (ShapeUtil::IsTuple(*result_shape)) { - // We must record element buffers of tuples as well to avoid leaks. - DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); + // Copy arguments to device. + std::vector> argument_buffers; + std::vector argument_buffer_ptrs; + for (Literal* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - backend().transfer_manager()->ShallowCopyTupleFromDevice( - backend().default_stream_executor(), result, *result_shape)); - - // A tuple may contain the same buffer in more than one element. Keep track - // of the buffers already added to avoid duplicates in allocations_. - std::set added_opaques; - for (auto element_buffer : element_buffers) { - if (added_opaques.count(element_buffer.opaque()) == 0) { - CHECK(element_buffer.opaque() != nullptr); - added_opaques.insert(element_buffer.opaque()); - allocations_.push_back(element_buffer); - } - } + std::unique_ptr argument_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + argument->shape(), run_options.allocator(), + run_options.device_ordinal())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + stream.parent(), *argument, *argument_buffer)); + argument_buffers.push_back(std::move(argument_buffer)); + argument_buffer_ptrs.push_back(argument_buffers.back().get()); } - return result; -} - -StatusOr HloRunner::TransferToDevice( - const Literal& literal) { - // Allocate memory on the device using the stream executor. - int64 allocation_size = - backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); - se::DeviceMemoryBase allocation = - backend().default_stream_executor()->AllocateArray( - allocation_size); - allocations_.push_back(allocation); - - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, &allocation)); - - return allocation; -} - -StatusOr> HloRunner::TransferFromDevice( - const Shape& shape, se::DeviceMemoryBase device_base) { - auto literal = MakeUnique(); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), device_base, shape, shape, - literal.get())); - return std::move(literal); -} + TF_ASSIGN_OR_RETURN( + std::unique_ptr result, + executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs, + /*hlo_execution_profile=*/nullptr)); -StatusOr> HloRunner::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - Shape result_shape; + // Create a ScopedShapedBuffer of the result to manage deallocation. This will + // deallocate all the device memory when it goes out of scope. TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase device_base, - Execute(std::move(module), arguments, &result_shape, run_hlo_passes)); - return TransferFromDevice(result_shape, device_base); + std::unique_ptr scoped_result, + ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator())); + + auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice( + stream.parent(), *scoped_result); + if (result_literal.ok()) { + VLOG(4) << "Executed binary and got result: " + << result_literal.ValueOrDie()->ToString(); + } else { + VLOG(4) << "Executed binary and got status: " + << result_literal.status().ToString(); + } + return result_literal; } Backend& HloRunner::backend() { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 95cddafc91ff40948efc4b0744343d994cf84f3a..d4b221fb52dff64dda264a931df6fd19b86e5260 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -35,7 +35,8 @@ namespace xla { // A base class for running an HloModule. This executes the given HloModule on a // certain backend directly without using the client interface. HloModule can be -// explicitly built, or loaded from a serialization file (e.g., hlo proto file). +// explicitly built, or loaded from a serialization file (e.g., hlo proto +// file), or parsed from a hlo textual IR string. class HloRunner { public: HloRunner(); @@ -44,6 +45,12 @@ class HloRunner { ~HloRunner(); + // Converts an HloModule from the given hlo textual IR string (in + // HloModule::ToString format). + static StatusOr> CreateModuleFromString( + const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options); + // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. Will try to parse the filename as binary proto, then try as // text proto if that fails. @@ -65,35 +72,13 @@ class HloRunner { // Executes the given module with given literals as input and returns the // result as a Literal. The LiteralPtr type accepts Literal* or // std::unique_ptr. - // If run_hlo_passes is true, the module will be executed without Hlo + // + // If run_hlo_passes is false, the module will be executed without Hlo // optimization. template StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals, - bool run_hlo_passes = true); - - // Executes the given module and returns a global data handle. - StatusOr Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape, bool run_hlo_passes = true); - - // Transfers the given literal to the device and returns the data handle. - StatusOr TransferToDevice( - const Literal& literal); - - // Transfers the array referred to by the given handle from the device and - // returns as a Literal. - StatusOr> TransferFromDevice( - const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); - - // Executes the given module and return the result as a Literal. - StatusOr> ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes = true); // If backend is not created in the constructor, creates and returns the @@ -104,9 +89,12 @@ class HloRunner { Backend& backend(); private: - struct EigenThreadPoolWrapper; + StatusOr> ExecuteInternal( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true); - std::vector allocations_; + struct EigenThreadPoolWrapper; std::unique_ptr thread_pool_wrapper_; @@ -116,15 +104,14 @@ class HloRunner { template StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals, + const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes) { - std::vector arguments; - for (const auto& literal : literals) { - TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument, - TransferToDevice(*literal)); - arguments.push_back(argument); + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + for (const auto& argument : arguments) { + argument_pointers.push_back(&*argument); } - return ExecuteAndTransfer(std::move(module), arguments, run_hlo_passes); + return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 8ccbcaeee4a9c9e94b344231953e20ac8f4b2053..2594c29efd717b3bead34d326c28c7efdf093c50 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -31,6 +31,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +using ::tensorflow::strings::HumanReadableNumBytes; + namespace xla { StatusOr MinimumMemoryForSequence( @@ -367,7 +369,17 @@ StatusOr MinimumMemoryForComputation( StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm == SchedulerAlgorithm::kListSchedule) { + return ListScheduler::Run(computation, points_to_analysis, size_function); + } + if (algorithm == SchedulerAlgorithm::kDfsSchedule) { + return RunDFSMemoryScheduler(computation, points_to_analysis, + size_function); + } + // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. @@ -382,7 +394,7 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); TF_ASSIGN_OR_RETURN( std::vector dfs_sequence, @@ -391,13 +403,15 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); if (list_memory <= dfs_memory) { - VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Chose min-memory list sequence: " + << HumanReadableNumBytes(list_memory); return list_sequence; } else { - VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Chose min-memory dfs sequence: " + << HumanReadableNumBytes(dfs_memory); return dfs_sequence; } } @@ -405,27 +419,30 @@ StatusOr> CreateMemoryMinimizingSequence( } // namespace StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { +CreateMemoryMinimizingSequence(const HloModule& module, + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); for (const auto* computation : module.MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN(sequence[computation], - CreateMemoryMinimizingSequence( - *computation, *points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + sequence[computation], + CreateMemoryMinimizingSequence(*computation, *points_to_analysis, + size_function, algorithm)); } return sequence; } StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function); + size_function, algorithm); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index ec92a56b962152b15981f868369683144aa7c76a..1d1eb1e064f75c2220b39e84b010e720a0c37880 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -33,17 +33,28 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); +enum class SchedulerAlgorithm { + kListSchedule, + kDfsSchedule, + + // Selects the available scheduler algorithm that had the minimum memory in + // the resulting sequence (a la MinimumMemoryForSequence). + kAuto, +}; + // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function); + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); // Overload of above that computes the sequence for a single computation. StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index d1adec31c21fe55001db4d522ddda27dd538bc95..447c2446668253c932b44b51b2db22bfd47f9957 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -246,7 +246,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, // The tile rank must be the same as the input rank. if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank"); + "Tile rank is different to the input rank. sharding=", ToString(), + ", input_shape=", ShapeUtil::HumanString(shape)); } // The tile shape must not be the same as the input shape without maximal_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 1a6988a2dc872a39ff6b0551adf7ddb871f0d72a..7263198385cf0c84b1dac1e15177dcac99adaafb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -80,6 +80,17 @@ class HloSharding { return HloSharding(flattened_list); } + // Creates a new sharding for a tuple type. The requested tuple shape must not + // be nested. For nested tuples, use the ShapeTree overload. + static HloSharding Tuple(const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)); + CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); + return HloSharding(flattened_list); + } + // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 101a710d1cad9401134fdfe1d0ec9df241bc01e1..3dc733940fc89952bd5e75a9b28d9cbf356f8000 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -166,7 +166,7 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); } else { layout_string = StrCat( - "{", Join(instruction->shape().layout().minor_to_major(), ","), "}"); + "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}"); } attrs["layout"].set_s(layout_string); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 15188c4057eca8eea1805e599cd020c045fdd10a..9d9cf0c0f67f50a13f6d966079b3f9748b0a52e9 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -14,412 +14,400 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { -namespace { +Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { + return CheckUnaryShape(hlo); +} -// Visitor which verifies that the output shape is correctly set. Verifies -// against the inferred shape for the instruction. -// TODO(b/26024837): Check output shape for all instruction types. -class ShapeVerifier : public DfsHloVisitor { - public: - explicit ShapeVerifier( - const std::function& shape_size_fn) - : shape_size_fn_(shape_size_fn) {} +Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) { + return CheckBinaryShape(hlo); +} - Status HandleElementwiseUnary(HloInstruction* hlo) override { - return CheckUnaryShape(hlo); - } +Status ShapeVerifier::HandleClamp(HloInstruction* clamp) { + return CheckTernaryShape(clamp); +} - Status HandleElementwiseBinary(HloInstruction* hlo) override { - return CheckBinaryShape(hlo); - } +Status ShapeVerifier::HandleSelect(HloInstruction* select) { + return CheckTernaryShape(select); +} - Status HandleClamp(HloInstruction* clamp) override { - return CheckTernaryShape(clamp); +Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { + std::vector operand_shapes; + for (const HloInstruction* operand : concatenate->operands()) { + operand_shapes.push_back(&operand->shape()); } + return CheckShape(concatenate, + ShapeInference::InferConcatOpShape( + operand_shapes, concatenate->concatenate_dimension())); +} - Status HandleSelect(HloInstruction* select) override { - return CheckTernaryShape(select); - } +Status ShapeVerifier::HandleConvert(HloInstruction* convert) { + return CheckShape(convert, ShapeInference::InferConvertShape( + convert->operand(0)->shape(), + convert->shape().element_type())); +} - Status HandleConcatenate(HloInstruction* concatenate) override { - std::vector operand_shapes; - for (const HloInstruction* operand : concatenate->operands()) { - operand_shapes.push_back(&operand->shape()); - } - return CheckShape( - concatenate, ShapeInference::InferConcatOpShape( - operand_shapes, concatenate->concatenate_dimension())); - } +Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { + return CheckShape(convert, ShapeInference::InferBitcastConvertShape( + convert->operand(0)->shape(), + convert->shape().element_type())); +} - Status HandleConvert(HloInstruction* convert) override { - return CheckShape(convert, ShapeInference::InferConvertShape( - convert->operand(0)->shape(), - convert->shape().element_type())); - } +Status ShapeVerifier::HandleCopy(HloInstruction* copy) { + return CheckUnaryShape(copy); +} - Status HandleBitcastConvert(HloInstruction* convert) override { - return CheckShape(convert, ShapeInference::InferBitcastConvertShape( - convert->operand(0)->shape(), - convert->shape().element_type())); - } +Status ShapeVerifier::HandleDot(HloInstruction* dot) { + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferDotOpShape( + dot->operand(0)->shape(), dot->operand(1)->shape(), + dot->dot_dimension_numbers())); + return CheckShape(dot, expected); +} - Status HandleCopy(HloInstruction* copy) override { - return CheckUnaryShape(copy); - } +Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { + TF_ASSIGN_OR_RETURN( + const Shape expected, + ShapeInference::InferConvolveShape( + convolution->operand(0)->shape(), convolution->operand(1)->shape(), + convolution->window(), convolution->convolution_dimension_numbers())); + return CheckShape(convolution, expected); +} - Status HandleDot(HloInstruction* dot) override { - return CheckBinaryShape(dot); - } +Status ShapeVerifier::HandleFft(HloInstruction* fft) { + TF_ASSIGN_OR_RETURN( + const Shape expected, + ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), + fft->fft_length())); + return CheckShape(fft, expected); +} - Status HandleConvolution(HloInstruction* convolution) override { - TF_ASSIGN_OR_RETURN( - const Shape expected, - ShapeInference::InferConvolveShape( - convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), - convolution->convolution_dimension_numbers())); - return CheckShape(convolution, expected); +Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { + std::vector operand_shapes; + for (const HloInstruction* operand : crs->operands()) { + operand_shapes.push_back(&operand->shape()); } + return CheckShape(crs, + ShapeInference::InferCrossReplicaSumShape(operand_shapes)); +} - Status HandleCrossReplicaSum(HloInstruction* crs) override { - return CheckShape(crs, ShapeInference::InferCrossReplicaSumShape( - crs->operand(0)->shape())); - } +Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { + return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( + reduce_precision->operand(0)->shape(), + reduce_precision->exponent_bits(), + reduce_precision->mantissa_bits())); +} - Status HandleReducePrecision(HloInstruction* reduce_precision) override { - return CheckShape(reduce_precision, - ShapeInference::InferReducePrecisionShape( - reduce_precision->operand(0)->shape(), - reduce_precision->exponent_bits(), - reduce_precision->mantissa_bits())); - } +Status ShapeVerifier::HandleInfeed(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleInfeed(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleOutfeed(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleOutfeed(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleRng(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleRng(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { + return CheckShape( + reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), + reverse->dimensions())); +} - Status HandleReverse(HloInstruction* reverse) override { - return CheckShape( - reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), - reverse->dimensions())); - } +Status ShapeVerifier::HandleSort(HloInstruction* sort) { + return CheckUnaryShape(sort); +} - Status HandleSort(HloInstruction* sort) override { - return CheckUnaryShape(sort); - } +Status ShapeVerifier::HandleConstant(HloInstruction* constant) { + return CheckShape(constant, constant->literal().shape()); +} - Status HandleConstant(HloInstruction* constant) override { - return CheckShape(constant, constant->literal().shape()); - } +Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { + return CheckShape(get_tuple_element, + ShapeInference::InferGetTupleElementShape( + get_tuple_element->operand(0)->shape(), + get_tuple_element->tuple_index())); +} - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override { - return CheckShape(get_tuple_element, - ShapeInference::InferGetTupleElementShape( - get_tuple_element->operand(0)->shape(), - get_tuple_element->tuple_index())); - } +Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + return CheckShape( + reduce, + ShapeInference::InferReduceShape( + reduce->operand(0)->shape(), reduce->operand(1)->shape(), + reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); +} - Status HandleReduce(HloInstruction* reduce) override { - return CheckShape( - reduce, - ShapeInference::InferReduceShape( - reduce->operand(0)->shape(), reduce->operand(1)->shape(), - reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); - } +Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { + return tensorflow::Status::OK(); +} - Status HandleBitcast(HloInstruction* bitcast) override { - // Bitcasts can be any shape, as long as the size matches the operand size. - TF_RET_CHECK(shape_size_fn_(bitcast->shape()) == - shape_size_fn_(bitcast->operand(0)->shape())); - return tensorflow::Status::OK(); +Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { + // HLO broadcast has no exact analog at the proto level so there is no + // ShapeInference method. Check the output shape explicitly. + const Shape& operand_shape = broadcast->operand(0)->shape(); + TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == + broadcast->dimensions().size()); + for (int64 operand_dimension = 0; + operand_dimension < ShapeUtil::Rank(operand_shape); + ++operand_dimension) { + int64 output_dimension = broadcast->dimensions()[operand_dimension]; + TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == + operand_shape.dimensions(operand_dimension)); } + return tensorflow::Status::OK(); +} - Status HandleBroadcast(HloInstruction* broadcast) override { - // HLO broadcast has no exact analog at the proto level so there is no - // ShapeInference method. Check the output shape explicitly. - const Shape& operand_shape = broadcast->operand(0)->shape(); - TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == - broadcast->dimensions().size()); - for (int64 operand_dimension = 0; - operand_dimension < ShapeUtil::Rank(operand_shape); - ++operand_dimension) { - int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension)); - } - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == + ShapeUtil::ElementsIn(reshape->operand(0)->shape())); + return tensorflow::Status::OK(); +} - Status HandleReshape(HloInstruction* reshape) override { - TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == - ShapeUtil::ElementsIn(reshape->operand(0)->shape())); - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { + return CheckShape( + transpose, ShapeInference::InferTransposeShape( + transpose->operand(0)->shape(), transpose->dimensions())); +} - Status HandleTranspose(HloInstruction* transpose) override { - return CheckShape(transpose, ShapeInference::InferTransposeShape( - transpose->operand(0)->shape(), - transpose->dimensions())); - } +Status ShapeVerifier::HandleParameter(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleParameter(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleFusion(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleFusion(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleCall(HloInstruction* call) { + // The shape of kCall should match the shape of the computation it calls. + return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); +} - Status HandleCall(HloInstruction* call) override { - // The shape of kCall should match the shape of the computation it calls. - return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); - } +Status ShapeVerifier::HandleCustomCall(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleCustomCall(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleSlice(HloInstruction* slice) { + return CheckShape(slice, + ShapeInference::InferSliceShape( + slice->operand(0)->shape(), slice->slice_starts(), + slice->slice_limits(), slice->slice_strides())); +} - Status HandleSlice(HloInstruction* slice) override { - return CheckShape(slice, - ShapeInference::InferSliceShape( - slice->operand(0)->shape(), slice->slice_starts(), - slice->slice_limits(), slice->slice_strides())); - } +Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { + return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( + dynamic_slice->operand(0)->shape(), + dynamic_slice->operand(1)->shape(), + dynamic_slice->dynamic_slice_sizes())); +} - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { - return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( - dynamic_slice->operand(0)->shape(), - dynamic_slice->operand(1)->shape(), - dynamic_slice->dynamic_slice_sizes())); - } +Status ShapeVerifier::HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) { + return CheckShape(dynamic_update_slice, + ShapeInference::InferDynamicUpdateSliceShape( + dynamic_update_slice->operand(0)->shape(), + dynamic_update_slice->operand(1)->shape(), + dynamic_update_slice->operand(2)->shape())); +} - Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override { - return CheckShape(dynamic_update_slice, - ShapeInference::InferDynamicUpdateSliceShape( - dynamic_update_slice->operand(0)->shape(), - dynamic_update_slice->operand(1)->shape(), - dynamic_update_slice->operand(2)->shape())); - } +Status ShapeVerifier::HandleTuple(HloInstruction* tuple) { + return CheckVariadicShape(tuple); +} - Status HandleTuple(HloInstruction* tuple) override { - return CheckVariadicShape(tuple); - } +Status ShapeVerifier::HandleMap(HloInstruction* map) { + std::vector operand_shapes; + int64 max_operand_rank = 0; + for (const HloInstruction* operand : map->operands()) { + operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + } + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + std::vector map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); + return CheckShape(map, ShapeInference::InferMapShape( + operand_shapes, + map->to_apply()->ComputeProgramShape(), map_dims)); +} - Status HandleMap(HloInstruction* map) override { - std::vector operand_shapes; - int64 max_operand_rank = 0; - for (const HloInstruction* operand : map->operands()) { - operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); - } - // TODO(b/65689298) Remove code below once Map is generalized to accept - // arbitrary map dimensions. - std::vector map_dims(max_operand_rank); - std::iota(map_dims.begin(), map_dims.end(), 0); - return CheckShape( - map, - ShapeInference::InferMapShape( - operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)); - } +Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { + return CheckShape( + reduce_window, + ShapeInference::InferReduceWindowShape( + reduce_window->operand(0)->shape(), + reduce_window->operand(1)->shape(), reduce_window->window(), + reduce_window->to_apply()->ComputeProgramShape())); +} - Status HandleReduceWindow(HloInstruction* reduce_window) override { - return CheckShape( - reduce_window, - ShapeInference::InferReduceWindowShape( - reduce_window->operand(0)->shape(), - reduce_window->operand(1)->shape(), reduce_window->window(), - reduce_window->to_apply()->ComputeProgramShape())); - } +Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { + return CheckShape( + instruction, + ShapeInference::InferSelectAndScatterShape( + instruction->operand(0)->shape(), + instruction->select()->ComputeProgramShape(), instruction->window(), + instruction->operand(1)->shape(), instruction->operand(2)->shape(), + instruction->scatter()->ComputeProgramShape())); +} - Status HandleSelectAndScatter(HloInstruction* instruction) override { - return CheckShape( - instruction, - ShapeInference::InferSelectAndScatterShape( - instruction->operand(0)->shape(), - instruction->select()->ComputeProgramShape(), instruction->window(), - instruction->operand(1)->shape(), instruction->operand(2)->shape(), - instruction->scatter()->ComputeProgramShape())); - } +Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + // The shape of kWhile should match the shape of the body computation it + // calls. + return CheckShape(xla_while, + xla_while->while_body()->ComputeProgramShape().result()); +} - Status HandleWhile(HloInstruction* xla_while) override { - // The shape of kWhile should match the shape of the body computation it - // calls. - return CheckShape(xla_while, - xla_while->while_body()->ComputeProgramShape().result()); - } +Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckShape( + conditional, + conditional->true_computation()->ComputeProgramShape().result())); + return CheckShape( + conditional, + conditional->false_computation()->ComputeProgramShape().result()); +} - Status HandleConditional(HloInstruction* conditional) override { - TF_RETURN_IF_ERROR(CheckShape( - conditional, - conditional->true_computation()->ComputeProgramShape().result())); - return CheckShape( - conditional, - conditional->false_computation()->ComputeProgramShape().result()); - } +Status ShapeVerifier::HandlePad(HloInstruction* pad) { + return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), + pad->operand(1)->shape(), + pad->padding_config())); +} - Status HandlePad(HloInstruction* pad) override { - return CheckShape(pad, - ShapeInference::InferPadShape(pad->operand(0)->shape(), - pad->operand(1)->shape(), - pad->padding_config())); - } +Status ShapeVerifier::HandleSend(HloInstruction* send) { + TF_RET_CHECK(send->users().size() == 1); + const HloInstruction* send_done = send->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape( + send, ShapeUtil::MakeTupleShape( + {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); +} - Status HandleSend(HloInstruction* send) override { - TF_RET_CHECK(send->users().size() == 1); - const HloInstruction* send_done = send->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape( - send, ShapeUtil::MakeTupleShape( - {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); - } +Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { + TF_RET_CHECK(send_done->operands().size() == 1); + const HloInstruction* send = send_done->operand(0); + TF_RET_CHECK(send->opcode() == HloOpcode::kSend); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape(send_done, ShapeUtil::MakeNil()); +} - Status HandleSendDone(HloInstruction* send_done) override { - TF_RET_CHECK(send_done->operands().size() == 1); - const HloInstruction* send = send_done->operand(0); - TF_RET_CHECK(send->opcode() == HloOpcode::kSend); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape(send_done, ShapeUtil::MakeNil()); - } +Status ShapeVerifier::HandleRecv(HloInstruction* recv) { + TF_RET_CHECK(recv->users().size() == 1); + const HloInstruction* recv_done = recv->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv, + ShapeUtil::MakeTupleShape( + {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); +} - Status HandleRecv(HloInstruction* recv) override { - TF_RET_CHECK(recv->users().size() == 1); - const HloInstruction* recv_done = recv->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv, - ShapeUtil::MakeTupleShape( - {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); - } +Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { + TF_RET_CHECK(recv_done->operands().size() == 1); + const HloInstruction* recv = recv_done->operand(0); + TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv_done, recv->shape().tuple_shapes(0)); +} - Status HandleRecvDone(HloInstruction* recv_done) override { - TF_RET_CHECK(recv_done->operands().size() == 1); - const HloInstruction* recv = recv_done->operand(0); - TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv_done, recv->shape().tuple_shapes(0)); - } +Status ShapeVerifier::HandleBatchNormTraining( + HloInstruction* batch_norm_training) { + return CheckShape(batch_norm_training, + ShapeInference::InferBatchNormTrainingShape( + batch_norm_training->operand(0)->shape(), + batch_norm_training->operand(1)->shape(), + batch_norm_training->operand(2)->shape(), + batch_norm_training->feature_index())); +} - Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { - return CheckShape(batch_norm_training, - ShapeInference::InferBatchNormTrainingShape( - batch_norm_training->operand(0)->shape(), - batch_norm_training->operand(1)->shape(), - batch_norm_training->operand(2)->shape(), - batch_norm_training->feature_index())); - } +Status ShapeVerifier::HandleBatchNormInference( + HloInstruction* batch_norm_inference) { + return CheckShape(batch_norm_inference, + ShapeInference::InferBatchNormInferenceShape( + batch_norm_inference->operand(0)->shape(), + batch_norm_inference->operand(1)->shape(), + batch_norm_inference->operand(2)->shape(), + batch_norm_inference->operand(3)->shape(), + batch_norm_inference->operand(4)->shape(), + batch_norm_inference->feature_index())); +} - Status HandleBatchNormInference( - HloInstruction* batch_norm_inference) override { - return CheckShape(batch_norm_inference, - ShapeInference::InferBatchNormInferenceShape( - batch_norm_inference->operand(0)->shape(), - batch_norm_inference->operand(1)->shape(), - batch_norm_inference->operand(2)->shape(), - batch_norm_inference->operand(3)->shape(), - batch_norm_inference->operand(4)->shape(), - batch_norm_inference->feature_index())); - } +Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( + batch_norm_grad->operand(0)->shape(), + batch_norm_grad->operand(1)->shape(), + batch_norm_grad->operand(2)->shape(), + batch_norm_grad->operand(3)->shape(), + batch_norm_grad->operand(4)->shape(), + batch_norm_grad->feature_index())); +} - Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override { - return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( - batch_norm_grad->operand(0)->shape(), - batch_norm_grad->operand(1)->shape(), - batch_norm_grad->operand(2)->shape(), - batch_norm_grad->operand(3)->shape(), - batch_norm_grad->operand(4)->shape(), - batch_norm_grad->feature_index())); +Status ShapeVerifier::CheckShape(const HloInstruction* instruction, + const Shape& expected_shape) { + if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) { + return InvalidArgument( + "Expected instruction to have shape compatible with %s, actual " + "shape is %s:\n%s", + ShapeUtil::HumanString(expected_shape).c_str(), + ShapeUtil::HumanString(instruction->shape()).c_str(), + instruction->ToString().c_str()); } + return tensorflow::Status::OK(); +} - Status FinishVisit(HloInstruction*) override { - return tensorflow::Status::OK(); +Status ShapeVerifier::CheckShape(const HloInstruction* instruction, + const StatusOr& expected_shape_status) { + if (!expected_shape_status.ok()) { + Status s = expected_shape_status.status(); + tensorflow::errors::AppendToMessage(&s, ", for instruction ", + instruction->ToString()); + return s; } + return CheckShape(instruction, expected_shape_status.ValueOrDie()); +} - private: - // Check the instruction's shape against the given expected shape and return - // an appropriate error if there is a mismatch. - Status CheckShape(const HloInstruction* instruction, - const Shape& expected_shape) { - if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) { - return InvalidArgument( - "Expected instruction to have shape compatible with %s, actual " - "shape is %s:\n%s", - ShapeUtil::HumanString(expected_shape).c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - instruction->ToString().c_str()); - } - return tensorflow::Status::OK(); - } +Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { + return CheckShape(instruction, + ShapeInference::InferUnaryOpShape(instruction->opcode(), + instruction->operand(0))); +} - // Overload which takes a StatusOr to reduce boilerplate in the caller. - Status CheckShape(const HloInstruction* instruction, - const StatusOr& expected_shape_status) { - if (!expected_shape_status.ok()) { - Status s = expected_shape_status.status(); - tensorflow::errors::AppendToMessage(&s, ", for instruction ", - instruction->ToString()); - return s; - } - return CheckShape(instruction, expected_shape_status.ValueOrDie()); - } +Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { + return CheckShape( + instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), + instruction->operand(0), + instruction->operand(1))); +} - // Check a unary (binary, etc) instruction's shape against the inferred shape. - Status CheckUnaryShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferUnaryOpShape( - instruction->opcode(), instruction->operand(0))); - } - Status CheckBinaryShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferBinaryOpShape( - instruction->opcode(), instruction->operand(0), - instruction->operand(1))); - } - Status CheckTernaryShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferTernaryOpShape( - instruction->opcode(), instruction->operand(0), - instruction->operand(1), instruction->operand(2))); - } - Status CheckVariadicShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferVariadicOpShape( - instruction->opcode(), instruction->operands())); - } +Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { + return CheckShape(instruction, + ShapeInference::InferTernaryOpShape( + instruction->opcode(), instruction->operand(0), + instruction->operand(1), instruction->operand(2))); +} - // Checks if the given two instructions shares the same channel id. - Status CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2) { - if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( - "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); - } - return tensorflow::Status::OK(); - } +Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { + return CheckShape(instruction, + ShapeInference::InferVariadicOpShape( + instruction->opcode(), instruction->operands())); +} - // Returns the size of a Shape in bytes. - const std::function shape_size_fn_; -}; +// Checks if the given two instructions shares the same channel id. +Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return FailedPrecondition( + "Expected to have the same channel id, actual channel ids are: %s " + "(%lld), %s (%lld)", + instr1->ToString().c_str(), instr1->channel_id(), + instr2->ToString().c_str(), instr2->channel_id()); + } + return tensorflow::Status::OK(); +} string ComputationsToString( tensorflow::gtl::ArraySlice computations) { @@ -429,7 +417,62 @@ string ComputationsToString( }); } -} // namespace +// Verifies various invariants about the structure of the HLO: +// +// (1) each instruction has a non-null parent() set to the HloComputation which +// contains it. +// +// (2) each computation has a non-null parent() set to the HloModule which +// contains it. +// +// (3) the operands of each instruction are in the same computation as the +// instruction. +Status VerifyHloStructure(HloModule* module) { + for (const HloComputation* computation : module->computations()) { + if (computation->parent() == nullptr) { + return FailedPrecondition("Computation %s has a null parent pointer", + computation->name().c_str()); + } + if (computation->parent() != module) { + return FailedPrecondition( + "Computation %s parent() does not point to parent module", + computation->name().c_str()); + } + + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->parent() == nullptr) { + return FailedPrecondition("Instruction %s has a null parent pointer", + instruction->name().c_str()); + } + if (instruction->parent() != computation) { + return FailedPrecondition( + "Instruction %s parent() does not point to parent computation", + instruction->name().c_str()); + } + } + } + + // Check that operands are in the same computation separately from verifying + // parent() correctness so conditions like a null HloInstruction::parent() are + // identified and reported explicitly above rather than reporting a mismatched + // operand. + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + for (int i = 0; i < instruction->operand_count(); ++i) { + const HloInstruction* operand = instruction->operand(i); + if (operand->parent() != instruction->parent()) { + return FailedPrecondition( + "Operand %d (%s) of instruction %s is in a different " + "computation: %s vs %s", + i, operand->name().c_str(), instruction->name().c_str(), + operand->parent()->name().c_str(), + instruction->parent()->name().c_str()); + } + } + } + } + return tensorflow::Status::OK(); +} Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. @@ -549,8 +592,9 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } StatusOr HloVerifier::Run(HloModule* module) { + TF_RETURN_IF_ERROR(VerifyHloStructure(module)); + tensorflow::gtl::FlatMap instructions; - ShapeVerifier shape_verifier(shape_size_fn_); for (auto* computation : module->computations()) { for (const auto& instruction : computation->instructions()) { @@ -630,7 +674,7 @@ StatusOr HloVerifier::Run(HloModule* module) { instructions[instruction->name()] = instruction; } - TF_RETURN_IF_ERROR(computation->Accept(&shape_verifier)); + TF_RETURN_IF_ERROR(computation->Accept(shape_verifier_.get())); } return false; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index e35a7f3642ccf91df37f69a3a11bd8c8e428b846..6368611f323ad7c1ebade4941260e12ed2c6e45f 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -18,14 +18,98 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" + namespace xla { +// Visitor which verifies that the output shape is correctly set. Verifies +// against the inferred shape for the instruction. +// TODO(b/26024837): Check output shape for all instruction types. +class ShapeVerifier : public DfsHloVisitor { + public: + Status HandleElementwiseUnary(HloInstruction* hlo) override; + Status HandleElementwiseBinary(HloInstruction* hlo) override; + Status HandleClamp(HloInstruction* clamp) override; + Status HandleSelect(HloInstruction* select) override; + Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleConvert(HloInstruction* convert) override; + Status HandleBitcastConvert(HloInstruction* convert) override; + Status HandleCopy(HloInstruction* copy) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleFft(HloInstruction* fft) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleReducePrecision(HloInstruction* reduce_precision) override; + Status HandleInfeed(HloInstruction*) override; + Status HandleOutfeed(HloInstruction*) override; + Status HandleRng(HloInstruction*) override; + Status HandleReverse(HloInstruction* reverse) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleConstant(HloInstruction* constant) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleReshape(HloInstruction* reshape) override; + Status HandleTranspose(HloInstruction* transpose) override; + Status HandleParameter(HloInstruction*) override; + Status HandleFusion(HloInstruction*) override; + Status HandleCall(HloInstruction* call) override; + Status HandleCustomCall(HloInstruction*) override; + Status HandleSlice(HloInstruction* slice) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleMap(HloInstruction* map) override; + Status HandleReduceWindow(HloInstruction* reduce_window) override; + Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandlePad(HloInstruction* pad) override; + Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; + Status HandleBatchNormInference( + HloInstruction* batch_norm_inference) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + + Status FinishVisit(HloInstruction*) override { + return tensorflow::Status::OK(); + } + + protected: + // Check the instruction's shape against the given expected shape and return + // an appropriate error if there is a mismatch. + Status CheckShape(const HloInstruction* instruction, + const Shape& expected_shape); + + // Overload which takes a StatusOr to reduce boilerplate in the caller. + Status CheckShape(const HloInstruction* instruction, + const StatusOr& expected_shape_status); + + // Check a unary (binary, etc) instruction's shape against the inferred shape. + Status CheckUnaryShape(const HloInstruction* instruction); + Status CheckBinaryShape(const HloInstruction* instruction); + Status CheckTernaryShape(const HloInstruction* instruction); + Status CheckVariadicShape(const HloInstruction* instruction); + + // Checks if the given two instructions shares the same channel id. + Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2); +}; + // HLO pass that verifies invariants of HLO instructions for each computation in // the module. class HloVerifier : public HloPassInterface { public: - explicit HloVerifier(const std::function& shape_size_fn) - : shape_size_fn_(shape_size_fn) {} + // Uses standard shape inference. + explicit HloVerifier() : shape_verifier_(MakeUnique()) {} + // Uses custom shape verification. + explicit HloVerifier(std::unique_ptr shape_verifier) + : shape_verifier_(std::move(shape_verifier)) {} ~HloVerifier() override = default; tensorflow::StringPiece name() const override { return "verifier"; } @@ -37,8 +121,8 @@ class HloVerifier : public HloPassInterface { // CHECKs various invariants of a fusion instruction. Status CheckFusionInstruction(HloInstruction* fusion) const; - // Returns the size of a Shape in bytes. - const std::function shape_size_fn_; + // Verifies shapes match inferred expectations. + std::unique_ptr shape_verifier_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a3b55decc5289e7e576d3c5897b333c0b1bc922 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +using HloVerifierTest = HloTestBase; + +TEST_F(HloVerifierTest, NullInstructionParent) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* negate = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(verifier().Run(module.get()).status()); + + negate->set_parent(nullptr); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer")); +} + +TEST_F(HloVerifierTest, NullComputationParent) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(verifier().Run(module.get()).status()); + + computation->set_parent(nullptr); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer")); +} + +TEST_F(HloVerifierTest, DifferentOperandParents) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* negate = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + HloComputation::Builder emb_builder(TestName()); + HloInstruction* emb_param = emb_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + module->AddEmbeddedComputation(emb_builder.Build()); + + TF_ASSERT_OK(verifier().Run(module.get()).status()); + TF_ASSERT_OK(negate->ReplaceOperandWith(0, emb_param)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("is in a different computation")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index b7c40fdeeb157fc74900bd9cf9d68a06a2cb1d56..13e4557317f74b3fb46f07fb91c339fd2f34752f 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -25,6 +25,7 @@ namespace xla { using tensorflow::strings::Appendf; using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; +using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; string HumanReadableProfileBuilder::ToString() const { @@ -43,7 +44,12 @@ string HumanReadableProfileBuilder::ToString() const { } else { bytes_per_sec = HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)); - bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + if (op.bytes_accessed > op.cycles) { + bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + } else { + bytes_per_cycle = + Printf("%.3fB", static_cast(op.bytes_accessed) / op.cycles); + } } double cycles_percent = 0; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index ba901b99e4f3c72c84c1ecdf4e19e58ad9ab6506..90e1f0acdc4cdeda280dabaab2df66b181d0f407 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -100,6 +100,7 @@ namespace xla { case HloOpcode::kDivide: case HloOpcode::kDot: case HloOpcode::kExp: + case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kLog: case HloOpcode::kMap: diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 2704a805a91b93c69b751cdb61305ea7780f0ef2..0819ab3b90b2360c6b0b2afaa89f322afe566eb3 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -92,6 +92,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 9183a1d1bfb8c2f6e1933c004f9c9f5f9ad8eced..0cb9b5d8107cd8bf468b07d5fe2a22930d9e8b8c 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -38,7 +39,6 @@ namespace xla { namespace interpreter { namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::interpreter; InterpreterExecutable::InterpreterExecutable( std::unique_ptr hlo_module) @@ -47,44 +47,18 @@ InterpreterExecutable::InterpreterExecutable( InterpreterExecutable::~InterpreterExecutable() {} -static se::DeviceMemoryBase AllocateSingleOutput( - sep::InterpreterExecutor* executor, const Literal& literal) { - int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape())); - void* buf = executor->Allocate(size); - const void* src = literal.InternalData(); - memcpy(buf, src, size); - return se::DeviceMemoryBase(buf, size); -} - -static se::DeviceMemoryBase AllocateOutputBuffer( - sep::InterpreterExecutor* executor, const Literal& literal) { - const Shape& shape = literal.shape(); - if (shape.element_type() != xla::TUPLE) { - return AllocateSingleOutput(executor, literal); - } else { - int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*))); - void** buf = reinterpret_cast(executor->Allocate(size)); - void** buf_rc = buf; - for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) { - se::DeviceMemoryBase out = - AllocateSingleOutput(executor, literal.tuple_literals(n)); - *buf++ = out.opaque(); - } - - return se::DeviceMemoryBase(buf_rc, size); - } -} - -StatusOr InterpreterExecutable::ExecuteOnStream( +StatusOr> InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); + se::StreamExecutor* executor = stream->parent(); + const se::Platform* platform = executor->platform(); VLOG(1) << "Execute " << module().name(); if (VLOG_IS_ON(2)) { for (const auto& a : arguments) { - VLOG(2) << "-- argument " << a.opaque(); + VLOG(2) << "-- argument " << *a; } } @@ -96,33 +70,32 @@ StatusOr InterpreterExecutable::ExecuteOnStream( "Mismatch between argument count and graph parameter count."); } - // Create the arguments as an vector of XLA literals + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, + TransferManager::GetForPlatform(platform)); + + // Transform the ShapedBuffer arguments into literals which the evaluator + // consumes. std::vector> arg_literals; - std::vector arg_literals_ptrs; for (int64 p = 0; p < computation->num_parameters(); ++p) { - // Create the input literal for the parameter - HloInstruction* param = computation->parameter_instruction(p); - arg_literals.emplace_back(Literal::CreateFromShape(param->shape())); - arg_literals_ptrs.push_back(arg_literals.back().get()); - - // Copy in the data from the stream_executor buffers - void* buffer = arg_literals.back()->MutableInternalData(); - memcpy(buffer, arguments[p].opaque(), - ShapeUtil::ByteSizeOf(param->shape())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr arg_literal, + transfer_manager->TransferLiteralFromDevice(executor, *arguments[p])); + arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(std::unique_ptr output, - evaluator.Evaluate(*computation, arg_literals_ptrs)); - - // Copy the result into the return buffer - perftools::gputools::StreamExecutor* executor(stream->parent()); - sep::InterpreterExecutor* interpreter_executor( - static_cast(executor->implementation())); - - se::DeviceMemoryBase ret = - AllocateOutputBuffer(interpreter_executor, *(output.get())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_literal, + evaluator.Evaluate>(*computation, arg_literals)); + + // Transform the result literal back into a ShapedBuffer. + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + transfer_manager->AllocateShapedBuffer( + result_literal->shape(), run_options->allocator(), + run_options->device_ordinal())); + TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( + executor, *result_literal, *result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -132,20 +105,13 @@ StatusOr InterpreterExecutable::ExecuteOnStream( execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); } - return ret; -} - -StatusOr> InterpreterExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - return tensorflow::errors::Unimplemented( - "ExecuteOnStream is not yet supported on Interpreter."); + return std::move(result); } -StatusOr InterpreterExecutable::ExecuteAsyncOnStream( +StatusOr> +InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { return tensorflow::errors::Unimplemented( "ExecuteAsyncOnStream is not yet supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 0e87eb90bff4b896fc4bc0efc4fa7b851631be6f..410110a1adf04c83001c38ed03f5d60dd203dc7e 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -43,21 +43,14 @@ class InterpreterExecutable : public Executable { InterpreterExecutable(std::unique_ptr hlo_module); ~InterpreterExecutable() override; - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; static int64 ShapeSizeBytes(const Shape& shape); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 0bb3259ef43915067e614e72038387e8300ecc41..68371910d76f42c0b6d4b1adad9d6a83bdb858e6 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -85,7 +85,7 @@ bool InterpreterExecutor::HostCallback(Stream *stream, bool InterpreterExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { AsExecutorStream(dependent)->EnqueueTask( - [other]() { other->BlockHostUntilDone(); }); + [other]() { SE_CHECK_OK(other->BlockHostUntilDone()); }); AsExecutorStream(dependent)->BlockUntilDone(); return true; } @@ -100,9 +100,9 @@ bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { return true; } -bool InterpreterExecutor::BlockHostUntilDone(Stream *stream) { +port::Status InterpreterExecutor::BlockHostUntilDone(Stream *stream) { AsExecutorStream(stream)->BlockUntilDone(); - return true; + return port::Status::OK(); } DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const { diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index c59b2ccb1505b78be0c459ac9311428d65cc7e44..c5d07e906dafb033905c50c604069e80e1ce80cd 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -157,7 +157,7 @@ class InterpreterExecutor : public internal::StreamExecutorInterface { bool StartTimer(Stream *stream, Timer *timer) override; bool StopTimer(Stream *stream, Timer *timer) override; - bool BlockHostUntilDone(Stream *stream) override; + port::Status BlockHostUntilDone(Stream *stream) override; int PlatformDeviceCount() override { return 1; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 7eda7c2284c2457703fcfcd4226172e41dd4ae01..f80dace8775c5ed31addb4a3d134f53005c6df71 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -369,8 +369,9 @@ string LayoutConstraints::ToString() const { } Status LayoutAssignment::AddMandatoryConstraints( - const ComputationLayout& computation_layout, HloComputation* computation, - LayoutConstraints* constraints) { + const ComputationLayout& computation_layout, + const ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, LayoutConstraints* constraints) { VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); @@ -403,6 +404,37 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(*shape_with_layout, instruction)); } + + if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv) { + CHECK(channel_constraints) + << "Multi-module layout assignment requires ChannelLayoutConstraints"; + int64 channel_id = instruction->channel_id(); + if (!channel_constraints->IsChannelConstrained(channel_id)) { + continue; + } + if (instruction->opcode() == HloOpcode::kSend) { + // TODO(b/68493863): Change to use SetOperandLayout(). + const Shape send_buffer_shape = instruction->operand(0)->shape(); + TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape)); + Shape new_buffer_shape = channel_constraints->LayoutShapeForChannel( + send_buffer_shape, instruction->channel_id()); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + new_buffer_shape, instruction->operand(0))); + } else { + const Shape recv_buffer_shape = + ShapeUtil::GetTupleElementShape(instruction->shape(), 0); + TF_RET_CHECK(ShapeUtil::IsArray(recv_buffer_shape)); + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + constraints->points_to_analysis().GetBufferDefinedAt(instruction, + {0})); + Shape new_shape = channel_constraints->LayoutShapeForChannel( + recv_buffer_shape, instruction->channel_id()); + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + new_shape.layout(), *buffer, /*mandatory=*/true)); + } + } } // Constrain layouts of instructions which call computations which have @@ -476,17 +508,14 @@ Status LayoutAssignment::AddMandatoryConstraints( body_layout.result_shape(), instruction, 0, /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { + if (!CustomCallRequiresMajorFirstLayout(instruction)) { + continue; + } // Add constraints for kCustomCall instruction operands and instructions. - // For now we only support row major layouts for all inputs and outputs. - auto row_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - - Shape result_shape(row_major_shape(instruction->shape())); + // For now we only support major-first layouts for all inputs and outputs. + Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout( + instruction->shape().element_type(), + AsInt64Slice(instruction->shape().dimensions())); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(result_shape, instruction)); for (int64 i = 0; i < instruction->operand_count(); ++i) { @@ -496,7 +525,10 @@ Status LayoutAssignment::AddMandatoryConstraints( continue; } - Shape row_major_operand_shape(row_major_shape(operand_shape)); + Shape row_major_operand_shape = + ShapeUtil::MakeShapeWithDescendingLayout( + operand_shape.element_type(), + AsInt64Slice(operand_shape.dimensions())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( row_major_operand_shape, instruction, i, /*mandatory=*/true)); } @@ -530,9 +562,11 @@ Status CheckCallLayout(HloInstruction* call, Status CheckCustomCallLayout(HloInstruction* custom_call) { for (const HloInstruction* operand : custom_call->operands()) { TF_RET_CHECK( + ShapeUtil::IsOpaque(operand->shape()) || LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); } TF_RET_CHECK( + ShapeUtil::IsOpaque(custom_call->shape()) || LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout())); return Status::OK(); } @@ -601,11 +635,9 @@ Status CheckConstantLayout(HloInstruction* constant) { return Status::OK(); } -// Check that all layouts in the module have been set and satisfy all necessary -// conditions. -Status CheckLayouts( - HloModule* module, - const std::map& computation_layouts) { +} // namespace + +Status LayoutAssignment::CheckLayouts(HloModule* module) { TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); for (auto* computation : module->MakeNonfusionComputations()) { @@ -649,10 +681,12 @@ Status CheckLayouts( case HloOpcode::kCall: TF_RETURN_IF_ERROR(CheckCallLayout( instruction, - FindOrDie(computation_layouts, instruction->to_apply()))); + FindOrDie(computation_layouts_, instruction->to_apply()))); break; case HloOpcode::kCustomCall: - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); + if (CustomCallRequiresMajorFirstLayout(instruction)) { + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); + } break; case HloOpcode::kFusion: TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); @@ -660,7 +694,7 @@ Status CheckLayouts( case HloOpcode::kParameter: TF_RETURN_IF_ERROR(CheckParameterLayout( instruction, - FindOrDie(computation_layouts, instruction->parent()))); + FindOrDie(computation_layouts_, instruction->parent()))); break; case HloOpcode::kConstant: TF_RETURN_IF_ERROR(CheckConstantLayout(instruction)); @@ -668,8 +702,8 @@ Status CheckLayouts( case HloOpcode::kWhile: TF_RETURN_IF_ERROR(CheckWhileLayout( instruction, - FindOrDie(computation_layouts, instruction->while_condition()), - FindOrDie(computation_layouts, instruction->while_body()))); + FindOrDie(computation_layouts_, instruction->while_condition()), + FindOrDie(computation_layouts_, instruction->while_body()))); break; default: break; @@ -681,17 +715,18 @@ Status CheckLayouts( // computation root. TF_RET_CHECK(ShapeUtil::Equal( module->entry_computation()->root_instruction()->shape(), - FindOrDie(computation_layouts, module->entry_computation()) + FindOrDie(computation_layouts_, module->entry_computation()) .result_layout() .shape())); return Status::OK(); } -} // namespace - -LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout) - : entry_computation_layout_(entry_computation_layout) { +LayoutAssignment::LayoutAssignment( + ComputationLayout* entry_computation_layout, + ChannelLayoutConstraints* channel_constraints) + : entry_computation_layout_(entry_computation_layout), + channel_layout_constraints_(channel_constraints) { VLOG(1) << "entry computation layout given to layout assignment: " << entry_computation_layout_->ToString(); // Layouts of all parameter instructions must be set. @@ -711,8 +746,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( int64 operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - CHECK(ShapeUtil::IsArray(instruction->shape()) && - ShapeUtil::IsArray(operand->shape())); + CHECK(ShapeUtil::IsArray(instruction->shape())); + CHECK(ShapeUtil::IsArray(operand->shape())); if (instruction->IsElementwiseOnOperand(operand_no) && !ShapeUtil::IsScalar(operand->shape()) && @@ -742,7 +777,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Shape& output_shape = instruction->shape(); Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), - AsInt64Slice(output_layout.minor_to_major())); + LayoutUtil::MinorToMajor(output_layout)); Shape operand_shape = operand->shape(); *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); @@ -771,7 +806,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( int64 rank = ShapeUtil::Rank(instruction->shape()); std::vector new_minor_to_major(rank); for (int64 i = 0; i < rank; ++i) { - int64 output_dim = output_layout.minor_to_major(i); + int64 output_dim = LayoutUtil::Minor(output_layout, i); int64 operand_dim = instruction->dimensions(output_dim); new_minor_to_major[i] = operand_dim; } @@ -814,7 +849,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout( operand->shape().element_type(), AsInt64Slice(operand->shape().dimensions()), - AsInt64Slice(operand_layout.minor_to_major())); + LayoutUtil::MinorToMajor(operand_layout)); Shape output_shape = user->shape(); *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); @@ -844,7 +879,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( std::vector new_minor_to_major(rank); auto inverse_dimensions = InversePermutation(user->dimensions()); for (int64 i = 0; i < rank; ++i) { - int64 operand_dim = operand_layout.minor_to_major(i); + int64 operand_dim = LayoutUtil::Minor(operand_layout, i); int64 user_dim = inverse_dimensions[operand_dim]; new_minor_to_major[i] = user_dim; } @@ -1303,8 +1338,8 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - // Copy the root instrucion's result if the it does not match the result - // layout constraint + // Copy the root instruction's result if its layout does not match the result + // layout constraint. if (constraints.ResultLayout() != nullptr && !constraints.ResultLayout()->MatchesLayoutInShape( computation->root_instruction()->shape())) { @@ -1321,20 +1356,35 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, Status LayoutAssignment::RunOnComputation( const ComputationLayout& computation_layout, const TuplePointsToAnalysis& points_to_analysis, - HloComputation* computation) { + HloComputation* computation, + ChannelLayoutConstraints* channel_constraints) { DCHECK(computation_layout.LayoutIsSet()); InsertOrDie(&computation_layouts_, computation, computation_layout); VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + // Clear existing layouts of the instructions. All layouts must be assigned by + // the LayoutAssignment pass, except for Infeed, Outfeed, Parameters and the + // computation result. The latter two are specified in computation_layout, so + // we only need to keep the existing layouts for Infeed and Outfeed. Clearing + // the layouts here avoids hiding potential bugs in the layout assignment pass + // that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kInfeed || + instruction->opcode() == HloOpcode::kOutfeed) { + continue; + } + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); // Add constraints required for correctness on all backends (eg, entry // parameter layout constraints). - TF_RETURN_IF_ERROR( - AddMandatoryConstraints(computation_layout, computation, &constraints)); + TF_RETURN_IF_ERROR(AddMandatoryConstraints( + computation_layout, channel_constraints, computation, &constraints)); // Add any backend-specific constraints. TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints)); @@ -1373,7 +1423,20 @@ Status LayoutAssignment::RunOnComputation( // All logical buffers should have constraints at this point. All that // remains is assign the constraints to the buffers and infer layouts for // aliased buffers. - return AssignLayouts(constraints, computation); + TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation)); + + // Record the layouts assigned for any communication ops in + // channel_constraints so that they are constrained for future modules. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kSend) { + channel_constraints->ConstrainChannel( + instruction->channel_id(), instruction->operand(0)->shape().layout()); + } else if (instruction->opcode() == HloOpcode::kRecvDone) { + channel_constraints->ConstrainChannel(instruction->channel_id(), + instruction->shape().layout()); + } + } + return Status::OK(); } StatusOr LayoutAssignment::Run(HloModule* module) { @@ -1393,9 +1456,9 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // all callers of a computation will agree. for (auto* computation : module->MakeComputationPostOrder()) { if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_, - *points_to_analysis, - module->entry_computation())); + TF_RETURN_IF_ERROR(RunOnComputation( + *entry_computation_layout_, *points_to_analysis, + module->entry_computation(), channel_layout_constraints_)); } else if (computation->IsFusionComputation()) { continue; } else { @@ -1404,11 +1467,12 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // suboptimal. computation_layout.SetToDefaultLayout(); TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, - *points_to_analysis, computation)); + *points_to_analysis, computation, + channel_layout_constraints_)); } } - TF_RETURN_IF_ERROR(CheckLayouts(module, computation_layouts_)); + TF_RETURN_IF_ERROR(CheckLayouts(module)); VLOG(3) << "After layout assignment:"; XLA_VLOG_LINES(3, module->ToString()); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 0b97fba744923b8afc3fb539566b68f1bca47d38..6bfae2998609c0482b91368f1891ce1e8e43fa23 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -215,13 +215,62 @@ class LayoutConstraints { HloComputation* computation_; }; +// Contains constraints on the layout of channels; sends and recvs. +class ChannelLayoutConstraints { + public: + // Construct an empty constraint set. + ChannelLayoutConstraints() {} + + // Returns true if channel_id has a layout constraint. + bool IsChannelConstrained(int64 channel_id) const { + return constraints_.count(channel_id) > 0; + } + + // Given `shape`, apply the layout for `channel_id`. `channel_id` must already + // be constrained. + Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const { + CHECK(IsChannelConstrained(channel_id)); + *shape.mutable_layout() = constraints_.at(channel_id); + return shape; + } + + // Returns the layout constraint for `channel_id`, which must already be + // constrained. + Layout LayoutForChannel(int64 channel_id) const { + CHECK(IsChannelConstrained(channel_id)); + return constraints_.at(channel_id); + } + + // Adds a new layout constraint for `channel_id`. If a constraint for + // `channel_id` already exists, this operation requires that the new layout is + // the same as the previously constrained layout. + void ConstrainChannel(int64 channel_id, const Layout& layout) { + CHECK(!IsChannelConstrained(channel_id) || + LayoutUtil::Equal(layout, constraints_[channel_id])); + constraints_[channel_id] = layout; + } + + private: + std::unordered_map constraints_; +}; + // HLO pass which assigns layouts to all instructions in the HLO module while // satisfying all necessary invariants and minimizing cost. class LayoutAssignment : public HloPassInterface { public: // entry_computation_layout is modified to populate a layout for the result in // the case that no particular layout is requested. - explicit LayoutAssignment(ComputationLayout* entry_computation_layout); + // + // channel_constraints is both an input and output. Any sends or recvs that + // are present in channel_constraints will be layed out as constrained. Any + // unconstrained sends or recvs will be layed out as locally optimal and their + // layout will be added as a constraint to channel_constraints. + // + // If channel_constraints is nullptr, no kSend or kRecvs must be contained + // within any module passed to `Run`. + explicit LayoutAssignment( + ComputationLayout* entry_computation_layout, + ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} tensorflow::StringPiece name() const override { return "layout-assignment"; } @@ -247,6 +296,19 @@ class LayoutAssignment : public HloPassInterface { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); + // By default LayoutAssignment ensures that inputs and outputs of CustomCalls + // have the "major-first" layout (i.e. {n, n-1, ..., 0}). + // + // If this function returns true, LayoutAssignment does not set a layout for + // the given CustomCall. It's up to the backend to set one in + // AddBackendConstraints, if necessary. + // + // Precondition: instruction->opcode() == HloOpcode::kCustomCall. + virtual bool CustomCallRequiresMajorFirstLayout( + const HloInstruction* /*instruction*/) { + return true; + } + // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. virtual Status Verify(const HloInstruction* instruction) { @@ -283,9 +345,10 @@ class LayoutAssignment : public HloPassInterface { private: // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. - Status AddMandatoryConstraints(const ComputationLayout& computation_layout, - HloComputation* computation, - LayoutConstraints* constraints); + Status AddMandatoryConstraints( + const ComputationLayout& computation_layout, + const ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, LayoutConstraints* constraints); // This method can be overridden to add backend-specific constraints to the // layout of the instructions of a computation. This method is called after @@ -301,7 +364,8 @@ class LayoutAssignment : public HloPassInterface { // constrained. Status RunOnComputation(const ComputationLayout& computation_layout, const TuplePointsToAnalysis& points_to_analysis, - HloComputation* computation); + HloComputation* computation, + ChannelLayoutConstraints* channel_constraints); // Assign layouts to the instructions of a computation which satisfy the given // layout constraints. Copies may be added to satisfy the constraints. The @@ -315,7 +379,12 @@ class LayoutAssignment : public HloPassInterface { // required for correctness. Status PropagateConstraints(LayoutConstraints* constraints); + // Check that all layouts in the module have been set and satisfy all + // necessary conditions. + Status CheckLayouts(HloModule* module); + ComputationLayout* entry_computation_layout_; + ChannelLayoutConstraints* channel_layout_constraints_; protected: // Map containing the layouts of all computations assigned so diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 476e86fa72ad691cda52097c953ba15132f206a7..2c2a02f6375343d67dfb155bbb03729ff6e490d2 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -277,8 +277,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto b = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -312,8 +315,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); + HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index d878061f724de1c82f8285b0f082d0be4d5778df..ffc78bd5cfac3df1001d8125327607c85169ae92 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -48,6 +48,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", "@llvm//:core", @@ -156,18 +157,6 @@ cc_library( ], ) -cc_library( - name = "vector_support_library", - srcs = ["vector_support_library.cc"], - hdrs = ["vector_support_library.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "@llvm//:core", - ], -) - cc_library( name = "kernel_support_library", srcs = ["kernel_support_library.cc"], diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 7224bd689842d89563b374f3db3d4e314be18764..6384c7f46f5ebbedaeda232b40095611a5d738a4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -39,13 +39,27 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; int64 divisor = 1; - for (int64 dimension : layout_.minor_to_major()) { + for (int64 i = 0; i < layout_.minor_to_major_size(); ++i) { + int64 dimension = layout_.minor_to_major(i); int64 size_of_current_dimension = shape.dimensions(dimension); - // Emit IR instructions that compute - // (linear_index / divisor) % current_dimension - multidim_[dimension] = ir_builder->CreateURem( - ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)), - ir_builder->getInt64(size_of_current_dimension)); + + // If i is not the last dimension, compute + // (linear_index / divisor) % current_dimension. + // If i is the last dimension, we can skip the mod, because we assume that + // linear is in bounds. + // + // TODO(jlebar): We could add bounds checks here and elsewhere in this file, + // guarded under some sort of xla-memcheck flag. This might be particularly + // useful because cuda-memcheck can't help us much in XLA: Most of our + // memory lives in one big allocation, so cuda-memcheck can't detect + // out-of-bounds accesses. + auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); + if (i < layout_.minor_to_major_size() - 1) { + multidim_[dimension] = ir_builder->CreateURem( + quot, ir_builder->getInt64(size_of_current_dimension)); + } else { + multidim_[dimension] = quot; + } divisor *= size_of_current_dimension; } } @@ -244,8 +258,8 @@ llvm::Value* IrArray::EmitArrayElementAddress( // // getelementptr base_ptr_, 0, most major index, ..., most minor index std::vector gep_indices(1, ir_builder->getInt64(0)); - for (int64 i = shape_->layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape_->layout().minor_to_major(i); + for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_->layout(), i); gep_indices.push_back(actual_index[dimension]); } return ir_builder->CreateInBoundsGEP(base_ptr_, gep_indices, diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 29cc0f81bd2c06538e28d1b593ee6a897fea0f27..23d2d4e87d26f4988ebddcf20f5a27af6a7fe0d6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { void KernelSupportLibrary::For( @@ -62,4 +63,72 @@ void KernelSupportLibrary::If( false_block_generator(); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); } + +void KernelSupportLibrary::EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + KernelSupportLibrary::ArgumentVector arguments, + const std::function& + kernel_body_generator) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + llvm::Function* function = + module->getFunction(llvm_ir::AsStringRef(kernel_name)); + + int64 null_arg_idx = -1; + std::vector sanitized_args; + sanitized_args.reserve(arguments.size()); + for (int64 i = 0, e = arguments.size(); i < e; i++) { + if (arguments[i]) { + sanitized_args.push_back(arguments[i]); + } else { + CHECK_EQ(null_arg_idx, -1); + null_arg_idx = i; + } + } + + if (!function) { + VLOG(2) << "Generating kernel for " << kernel_name; + std::vector arg_types; + std::transform(sanitized_args.begin(), sanitized_args.end(), + std::back_inserter(arg_types), + [](llvm::Value* arg) { return arg->getType(); }); + + auto* function_type = llvm::FunctionType::get( + ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false); + + function = llvm_ir::CreateFunction( + function_type, llvm::GlobalValue::InternalLinkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, kernel_name, module); + + llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder); + + auto* entry_bb = + llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function); + auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(), + /*retVal=*/nullptr, entry_bb); + // Set the insert point to before return_inst. + ir_builder->SetInsertPoint(return_inst); + + std::vector arg_values; + /* + * clang on OSX doesn't like std::transform or range for loop here. + * See https://github.com/tensorflow/tensorflow/issues/15196 + */ + for (llvm::Function::arg_iterator arg = function->arg_begin(), + arg_e = function->arg_end(); + arg != arg_e; ++arg) { + arg_values.push_back(arg); + } + if (null_arg_idx != -1) { + arg_values.insert(arg_values.begin() + null_arg_idx, nullptr); + } + kernel_body_generator(arg_values); + } else { + VLOG(3) << "Re-using kernel for " << kernel_name; + } + + ir_builder->CreateCall(function, llvm_ir::AsArrayRef(sanitized_args)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 9bafb7b57740b7acd0286c113c8a0585c0f93689..827e092a3fa9116c461716b27c309033f7988745 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -118,6 +118,60 @@ class KernelSupportLibrary { const std::function& true_block_generator, const std::function& false_block_generator = []() {}); + using ArgumentVector = tensorflow::gtl::ArraySlice; + + // Generates the following control flow structure: + // + // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) { + // kernel_body_generator({arg0, arg1, ... arg`arguments.size()`}); + // } + // + // ... + // call @`kernel_name`(arguments[0], arguments[1] ...) + // ... + // + // If a function called `kernel_name` is already present in the module then + // that function is re-used. In that sense we're using the llvm::Module as a + // cache of outlined kernels, keyed by function name. + // + // If any of the values in `arguments` is nullptr (i.e. a nullptr + // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass + // in a nullptr llvm::Value* in its position to `kernel_body_generator`. + // Currently we only support at most one nullptr value in `arguments`. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + ArgumentVector arguments, + const std::function& kernel_body_generator); + + // Thin wrappers around the more general EmitAndCallOutlinedKernel above. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + const std::function& + kernel_body_generator) { + EmitAndCallOutlinedKernel( + enable_fast_math, optimize_for_size, ir_builder, kernel_name, + {arg0, arg1, arg2}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2]); + }); + } + + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + llvm::Value* arg3, + const std::function& kernel_body_generator) { + EmitAndCallOutlinedKernel( + enable_fast_math, optimize_for_size, ir_builder, kernel_name, + {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2], args[3]); + }); + } + private: llvm::IRBuilder<>* ir_builder_; bool prevent_unrolling_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index cd0c4a371e2b1cd0e1c52b77e47e8b081ab8e836..d2bcb38d09218c72183c7cece95bef6371006555 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -142,6 +142,13 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt8Ty(module->getContext()); case S16: case U16: + case BF16: + // For BF16 we just need some type that is 16 bits wide so that it will + // take up the right amount of space in memory. LLVM does not have a BF16 + // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so + // we can't map it directly to an LLVM type. We will not map a BF16 + // addition to an addition on this type (int16) - this is just the type + // used for storage. return llvm::Type::getInt16Ty(module->getContext()); case S32: case U32: @@ -200,8 +207,8 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { if (ShapeUtil::IsTuple(shape)) { // A tuple buffer is an array of pointers. result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); - } else { - for (int64 dimension : shape.layout().minor_to_major()) { + } else if (ShapeUtil::IsArray(shape)) { + for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { result_type = llvm::ArrayType::get(result_type, shape.dimensions(dimension)); } @@ -280,6 +287,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); break; + case BF16: + value = llvm::ConstantInt::get( + ir_element_type, + tensorflow::bit_cast(literal.Get(*multi_index))); + break; case F64: value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); @@ -304,7 +316,7 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, // decrements with each recursive call. We want to iterate through the // dimensions in major-to-minor order as we recurse so just index into // minor_to_major to get the dimension number for this level of the recursion. - int64 dimension = shape.layout().minor_to_major(dimension_index); + int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index); // Recursively call LiteralToConstant to construct subarrays for the // more-minor dimensions. Gather the subarrays into a vector for bundling into @@ -320,7 +332,7 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, if (elements.empty()) { element_type = ir_element_type; for (int i = 0; i < dimension_index; ++i) { - int64 index = shape.layout().minor_to_major(i); + int64 index = LayoutUtil::Minor(shape.layout(), i); element_type = llvm::ArrayType::get(element_type, shape.dimensions(index)); } @@ -676,5 +688,58 @@ Status DumpIRToDirectory(const string& directory_name, return f->Close(); } +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module) { + llvm::Function* function = + llvm::Function::Create(function_type, linkage, AsStringRef(name), module); + function->setCallingConv(llvm::CallingConv::C); + function->addFnAttr("no-frame-pointer-elim", "false"); + + if (enable_fast_math) { + function->addFnAttr("unsafe-fp-math", "true"); + function->addFnAttr("no-infs-fp-math", "true"); + function->addFnAttr("no-nans-fp-math", "true"); + function->addFnAttr("no-signed-zeros-fp-math", "true"); + } + + // Add the optize attribute to the function if optimizing for size. This + // controls internal behavior of some optimization passes (e.g. loop + // unrolling). + if (optimize_for_size) { + function->addFnAttr(llvm::Attribute::OptimizeForSize); + } + + return function; +} + +void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { + auto options = config.debug_options().xla_backend_extra_options(); + if (!options.empty()) { + std::vector fake_argv_storage; + fake_argv_storage.push_back(""); + for (const auto& it : options) { + // Skip options the XLA backend itself consumes. + if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { + if (it.second.empty()) { + fake_argv_storage.push_back(it.first); + } else { + fake_argv_storage.push_back(it.first + "=" + it.second); + } + } + } + + VLOG(2) << "Passing argv to LLVM:"; + std::vector fake_argv; + for (const auto& s : fake_argv_storage) { + fake_argv.push_back(s.c_str()); + VLOG(2) << s; + } + llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); + } +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 063ead2b647d8fc5cc4f67004aaded80a2191fe9..4a10ec466dae6fdb56546fb8d8b353dcff6a5b8d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -281,6 +282,16 @@ Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized); +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module); + +// Extracts the xla_backend_extra_options from `config` and passes those that +// don't start with xla_ to LLVM. +void InitializeLLVMCommandLineOptions(const HloModuleConfig& config); + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 6fa4cd08c9e0ac30b83c0e2b49d98d930c2e15df..a5f7c850c33757fe8d48567ade35544d81224e46 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -99,8 +99,8 @@ IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock( // dimension (of the target shape). ForLoopNest loop_nest(loop_name, ir_builder_); IrArray::Index array_index(shape_.dimensions_size()); - for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape_.layout().minor_to_major(i); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_.layout(), i); std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 06f43bd3cb2376d34a3104133c868c4f4e5cc730..2194d24257d0ccd04f3c9625412116eba01acd8c 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -84,15 +84,30 @@ StatusOr> LocalService::CompileExecutable( // Validate incoming layouts. if (argument_layouts.size() != program_shape->parameters_size()) { return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", + "Invalid number of arguments for computation: expected %d, got %zu.", program_shape->parameters_size(), argument_layouts.size()); } for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& argument_shape = *argument_layouts[i]; TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { + tensorflow::gtl::optional metadata = + user_computation->ParameterMetadata(i); + auto metadata_string = [&metadata]() -> string { + if (!metadata.has_value()) { + return ""; + } + CHECK(metadata.value() != nullptr); + const OpMetadata& m = *metadata.value(); + if (!m.source_file().empty()) { + return tensorflow::strings::Printf( + " (%s:%d)", m.source_file().c_str(), m.source_line()); + } + return ""; + }; return InvalidArgument( - "invalid argument shape for argument %d, expected %s, got %s", i, + "Invalid argument shape for argument %d%s, expected %s, got %s.", i, + metadata_string().c_str(), ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), ShapeUtil::HumanString(argument_shape).c_str()); } @@ -118,10 +133,14 @@ StatusOr> LocalService::CompileExecutable( TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, execute_backend_->stream_executor(device_ordinal)); - std::vector argument_buffers( - argument_layouts.size()); return BuildExecutable(versioned_handle, std::move(module_config), - argument_buffers, execute_backend_.get(), executor); + execute_backend_.get(), executor); +} + +StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { + return backend().computation_placer()->DeviceId( + replica_number, /*computation=*/0, options_.number_of_replicas(), + /*computation_count=*/1); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 52c4346385eb663baa6e7579d7b3883ba084205b..acbc7268252881958190f416ab936d64430166e1 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -47,6 +47,13 @@ class LocalService : public Service { const tensorflow::gtl::ArraySlice argument_layouts, const Shape* result_layout, int device_ordinal); + // Returns the device ordinal that corresponds to the given replica number. + // + // This returns an error if there is not a one-to-one correspondence of + // replicas to device ordinals, but is useful as a short term mechanism for + // the "easy" case where a single replica is a single device. + StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index a0d08c288dbcc45e83a36ce7b094b04a9dbae532..7d8c05fffa4ab11d7dbf9956d2cb7ebd5bcdd3c4 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -17,12 +17,44 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + +bool IsAllowed(char character) { + auto c = static_cast(character); + return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; +} + +} // namespace + +NameUniquer::NameUniquer(const string& separator) { + CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + << "separator should comprises allowed characters only"; + separator_ = separator; +} + +/*static*/ string NameUniquer::GetSanitizedName(const string& name) { + string result = name; + CHECK(!result.empty()) << "name should not be empty"; + char c = static_cast(result[0]); + if (!isalpha(c) && c != '_') { + result[0] = '_'; + } + for (int i = 1; i < result.length(); i++) { + if (!IsAllowed(result[i])) { + result[i] = '_'; + } + } + return result; +} + string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { string root = prefix.empty() ? "name" : prefix.ToString(); + root = GetSanitizedName(root); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index ed379b52258463b960dea788721c2c4325ef0260..4139c2700b25e8600182a034a8ac6f4f041c12e6 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -28,14 +28,21 @@ namespace xla { // Simple stateful class that helps generate "unique" names. To use it, simply // call GetUniqueName as many times as needed. The names returned by // GetUniqueName are guaranteed to be distinct for this instance of the class. +// Note that the names will be sanitized to match regexp +// "[a-zA-Z_][a-zA-Z0-9_.-]*". class NameUniquer { public: - explicit NameUniquer(const string& separator = "__") - : separator_(separator) {} + // The separator must contain allowed characters only: "[a-zA-Z0-9_.-]". + explicit NameUniquer(const string& separator = "__"); - // Get a unique name in a string, with an optional prefix for convenience. + // Get a sanitized unique name in a string, with an optional prefix for + // convenience. string GetUniqueName(tensorflow::StringPiece prefix = ""); + // Sanitizes and returns the name. Unallowed characters will be replaced with + // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + static string GetSanitizedName(const string& name); + private: // The string to use to separate the prefix of the name from the uniquing // integer value. diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 9f0747a6e2175a968d8f3661ac51512009e86f29..4258cf16876ab46dce6df062ab701b1b1a4a7580 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -60,12 +60,30 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); +} + +TEST_F(NameUniquerTest, Sanitize) { + NameUniquer uniquer("_"); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); + EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); + EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); + EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + + // Invalid characters will be replaced with '_'. + EXPECT_EQ("bar", uniquer.GetUniqueName("bar<-1000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); + EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. - EXPECT_EQ(".10", uniquer.GetUniqueName(".10")); - EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10")); - EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar.")); - EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar.")); + EXPECT_EQ("_10", uniquer.GetUniqueName( + ".10")); // the leading '.' is replaced with '_'. + EXPECT_EQ("_10_1", uniquer.GetUniqueName(".10")); + EXPECT_EQ("_10_2", uniquer.GetUniqueName("_10")); + EXPECT_EQ("foobar_", uniquer.GetUniqueName("foobar_")); + EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } } // namespace diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 63f3bfb36cedeb44b190e1e8a5584d334f94b585..aa974ee61a27de9c19e97d8a6eb48f9261ce4bd9 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -33,10 +33,32 @@ namespace se = ::perftools::gputools; namespace xla { +using tensorflow::str_util::Lowercase; + // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; +// The name of the interpreter platform. +constexpr char kInterpreter[] = "interpreter"; + +namespace { + +string CanonicalPlatformName(const string& name) { + string platform_str = Lowercase(name); + // "cpu" and "host" mean the same thing. + if (platform_str == "cpu") { + platform_str = "host"; + } + // "gpu" and "cuda" mean the same thing. + if (platform_str == "gpu") { + platform_str = "cuda"; + } + return platform_str; +} + +} // namespace + /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { se::MultiPlatformManager::PlatformMap platform_map; @@ -78,7 +100,7 @@ PlatformUtil::GetSupportedPlatforms() { return platforms; } -/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { +/* static */ StatusOr PlatformUtil::GetSolePlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); if (platforms.empty()) { return NotFound("no platforms found"); @@ -87,26 +109,42 @@ PlatformUtil::GetSupportedPlatforms() { } // Multiple platforms present and we can't pick a reasonable default. - auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); }; - string platforms_string = tensorflow::str_util::Join(platforms, ", ", l); + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", platforms_string.c_str()); } -/*static*/ StatusOr PlatformUtil::GetPlatform( - const string& platform_name) { - using tensorflow::str_util::Lowercase; - string platform_str = Lowercase(platform_name); - // "cpu" and "host" mean the same thing. - if (platform_str == "cpu") { - platform_str = "host"; - } - // "gpu" and "cuda" mean the same thing. - if (platform_str == "gpu") { - platform_str = "cuda"; +/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { + TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + if (platforms.empty()) { + return NotFound("no platforms found"); + } else if (platforms.size() == 1) { + return platforms[0]; + } else if (platforms.size() == 2) { + for (int i = 0; i < 2; i++) { + if (Lowercase(platforms[i]->Name()) == kInterpreter && + Lowercase(platforms[1 - i]->Name()) != kInterpreter) { + return platforms[1 - i]; + } + } } + // Multiple platforms present and we can't pick a reasonable default. + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "must specify platform because more than one platform (except for the " + "interpreter platform) found: %s", + platforms_string.c_str()); +} + +/*static*/ StatusOr PlatformUtil::GetPlatform( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); for (se::Platform* platform : platforms) { if (Lowercase(platform->Name()) == platform_str) { @@ -116,6 +154,32 @@ PlatformUtil::GetSupportedPlatforms() { return InvalidArgument("platform %s not found", platform_name.c_str()); } +/*static*/ StatusOr PlatformUtil::GetPlatformExceptFor( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); + + TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); + std::vector matched; + for (se::Platform* platform : platforms) { + if (Lowercase(platform->Name()) != platform_name) { + matched.push_back(platform); + } + } + if (matched.empty()) { + return InvalidArgument("unable to find platform that is not %s", + platform_name.c_str()); + } + if (matched.size() == 1) { + return matched[0]; + } + string matched_string = tensorflow::str_util::Join( + matched, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "found multiple platforms %s, but expected one platform except for %s", + matched_string.c_str(), platform_name.c_str()); +} + // Returns whether the device underlying the given StreamExecutor is supported // by XLA. static bool IsDeviceSupported(se::StreamExecutor* executor) { diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index a59d4ffe87f568ac786e4b2d3bf6983bc0d4695a..69188820a70707d9c9be10b20fb7de92ad4d9873 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -37,16 +37,28 @@ class PlatformUtil { static StatusOr> GetSupportedPlatforms(); - // Convenience function which returns the default supported platform. If + // Convenience function which returns the default supported platform for + // tests. If exactly one supported platform is present, then this platform is + // the default platform. If exactly two platforms are present and one of them + // is the interpreter platform, then the other platform is the default + // platform. Otherwise returns an error. + static StatusOr GetDefaultPlatform(); + + // Convenience function which returns the sole supported platform. If // exactly one supported platform is present, then this platform is the // default platform. Otherwise returns an error. - static StatusOr GetDefaultPlatform(); + static StatusOr GetSolePlatform(); // Returns the platform according to the given name. Returns error if there is // no such platform. static StatusOr GetPlatform( const string& platform_name); + // Returns exactly one platform that does not have given name. Returns error + // if there is no such platform, or there are multiple such platforms. + static StatusOr GetPlatformExceptFor( + const string& platform_name); + // Returns a vector of StreamExecutors for the given platform. The vector is // indexed by device ordinal (device numbering used by StreamExecutor). If an // element is nullptr, then the device is present by not supported by XLA. diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 0fb90230f2f39a841973361f63d17af579a1342b..e62bafc50b0e1270702621c9ea7b2ee43e001fe0 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -101,8 +101,9 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( IsReshapeOrTranspose(operand) && !CanTriviallyChangeShape(operand->operand(0))) { VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " - << hlo->ToStringNoMetadata() << ":\n\t" - << operand->ToStringNoMetadata(); + << hlo->ToString(HloPrintOptions().set_print_metadata(false)) + << ":\n\t" + << operand->ToString(HloPrintOptions().set_print_metadata(false)); return operand; } } @@ -133,8 +134,9 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { bool AllOperandsHaveEasyShapeChanges( const HloInstruction* instruction, const HloInstruction* first_reshape_operand) { + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); VLOG(3) << "** Checking whether all operands have easy shape changes: " - << instruction->ToStringNoMetadata(); + << instruction->ToString(print_no_metadata); // Check whether all operands: // 0. Have the same dimensions as the output -- if not, it may be // implicitly broadcast, which can confound the movement's @@ -151,21 +153,21 @@ bool AllOperandsHaveEasyShapeChanges( VLOG(5) << "Operand shape differs from output shape; may be " "implicitly broadcast, so preventing " "movement\n\toperand: " - << operand->ToStringNoMetadata() - << "\n\tinstruction: " << instruction->ToStringNoMetadata(); + << operand->ToString(print_no_metadata) << "\n\tinstruction: " + << instruction->ToString(print_no_metadata); return false; } if (AreEquivalentReshapes(first_reshape_operand, operand)) { VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToStringNoMetadata() - << "\n\toperand: " << operand->ToStringNoMetadata(); + << first_reshape_operand->ToString(print_no_metadata) + << "\n\toperand: " << operand->ToString(print_no_metadata); continue; } if (CanTriviallyChangeShape(operand)) { VLOG(5) << "Operand can trivially change shape: " - << operand->ToStringNoMetadata(); + << operand->ToString(print_no_metadata); continue; } @@ -173,12 +175,12 @@ bool AllOperandsHaveEasyShapeChanges( // well. VLOG(5) << "Operand is neither equalivant to the first Reshape operand" "nor can trivially change shape: " - << operand->ToStringNoMetadata(); + << operand->ToString(print_no_metadata); return false; } VLOG(3) << "All operands have easy shape changes: " - << instruction->ToStringNoMetadata(); + << instruction->ToString(print_no_metadata); return true; } @@ -250,11 +252,13 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, return false; } + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); // At this point we've decided to sink reshape/transpose operands. const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); VLOG(3) << "** Sinking reshape or transpose: " - << instruction->ToStringNoMetadata() << "\n\tfirst reshape operand: " - << first_reshape_operand->ToStringNoMetadata() + << instruction->ToString(print_no_metadata) + << "\n\tfirst reshape operand: " + << first_reshape_operand->ToString(print_no_metadata) << "\n\tnew operand shape: " << ShapeUtil::HumanString(new_operand_shape); @@ -267,7 +271,7 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, continue; } VLOG(3) << "Updating operand #" << i << ": " - << operands[i]->ToStringNoMetadata(); + << operands[i]->ToString(print_no_metadata); operands[i] = UpdateOperand(computation, first_reshape_operand, new_operand_shape, operands[i]); } @@ -298,7 +302,7 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, switch (first_reshape_operand->opcode()) { case HloOpcode::kReshape: VLOG(3) << "Creating new reshape for new elementwise op: " - << new_elementwise->ToStringNoMetadata(); + << new_elementwise->ToString(print_no_metadata); new_reshape = HloInstruction::CreateReshape(instruction->shape(), new_elementwise); break; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index d997cab83f8c2bc74632e49f23e690ffb17b901a..fc848bdb036125e5dadb471be431d3d2523c6770 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -60,41 +60,32 @@ namespace xla { namespace { -// Copies the contents of an Allocation into a Literal proto. -tensorflow::Status LiteralFromAllocation(const Allocation* allocation, - const Shape& literal_shape, - Literal* literal) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - allocation->backend()->stream_executor(allocation->device_ordinal())); - return allocation->backend()->transfer_manager()->TransferLiteralFromDevice( - executor, allocation->device_memory(), allocation->shape(), literal_shape, - literal); -} - // Records the arguments used to invoke a computation in a SessionModule // proto. tensorflow::Status RecordArguments( - const tensorflow::gtl::ArraySlice arg_allocations, + const tensorflow::gtl::ArraySlice arguments, + se::StreamExecutor* executor, TransferManager* transfer_manager, SessionModule* module) { module->clear_arguments(); - for (const Allocation* allocation : arg_allocations) { - Literal argument; - TF_RETURN_IF_ERROR( - LiteralFromAllocation(allocation, allocation->shape(), &argument)); - *module->add_arguments() = argument.ToProto(); + for (const ShapedBuffer* argument : arguments) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, *argument)); + *module->add_arguments() = literal->ToProto(); } return tensorflow::Status::OK(); } // Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const Allocation* result_allocation, +tensorflow::Status RecordResult(const ShapedBuffer& result, + se::StreamExecutor* executor, + TransferManager* transfer_manager, SessionModule* module) { module->clear_result(); - Literal result; - TF_RETURN_IF_ERROR(LiteralFromAllocation( - result_allocation, result_allocation->shape(), &result)); - *module->mutable_result() = result.ToProto(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, result)); + *module->mutable_result() = literal->ToProto(); return tensorflow::Status::OK(); } @@ -152,7 +143,9 @@ int ServiceOptions::intra_op_parallelism_threads() const { Service::Service(const ServiceOptions& options, std::unique_ptr execute_backend) - : options_(options), execute_backend_(std::move(execute_backend)) { + : options_(options), + allocation_tracker_(execute_backend.get()), + execute_backend_(std::move(execute_backend)) { CHECK_GT(options_.number_of_replicas(), 0); if (execute_backend_) { if (execute_backend_->device_count() > 0) { @@ -235,35 +228,33 @@ tensorflow::Status Service::ValidateResultShapeWithLayout( return ShapeUtil::ValidateShape(shape_with_layout); } -StatusOr> Service::ResolveAndValidateArguments( +StatusOr> Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - const Backend* backend, int device_ordinal) { - std::vector allocations; + int device_ordinal) { + std::vector shaped_buffers; for (size_t i = 0; i < arguments.size(); ++i) { - auto allocation_status = allocation_tracker_.Resolve(*arguments[i]); - if (!allocation_status.ok()) { - return Status(allocation_status.status().code(), - StrCat(allocation_status.status().error_message(), ", ", + auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); + if (!buffer_status.ok()) { + return Status(buffer_status.status().code(), + StrCat(buffer_status.status().error_message(), ", ", "failed to resolve allocation for parameter ", i)); } - const Allocation* allocation = allocation_status.ValueOrDie(); + const ShapedBuffer* shaped_buffer = buffer_status.ValueOrDie(); // Verify allocation is same platform and device as the execution. - if (allocation->backend() != backend || - allocation->device_ordinal() != device_ordinal) { + if (shaped_buffer->platform() != execute_backend_->platform() || + shaped_buffer->device_ordinal() != device_ordinal) { return InvalidArgument( - "argument %lu is on device %s but computation will be executed " + "argument %lu is on device %s:%d but computation will be executed " "on device %s", - i, - allocation->backend() - ->device_name(allocation->device_ordinal()) - .c_str(), - backend->device_name(device_ordinal).c_str()); + i, shaped_buffer->platform()->Name().c_str(), + shaped_buffer->device_ordinal(), + execute_backend_->device_name(device_ordinal).c_str()); } - allocations.push_back(allocation); + shaped_buffers.push_back(shaped_buffer); } - return allocations; + return shaped_buffers; } StatusOr> Service::CreateModuleConfig( @@ -325,11 +316,11 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { - argument_shapes.push_back(&arg->shape()); + argument_shapes.push_back(&arg->on_host_shape()); } return CreateModuleConfig(program_shape, argument_shapes, &execution_options); } @@ -398,8 +389,6 @@ StatusOr>> Service::BuildExecutables( StatusOr> Service::BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, se::StreamExecutor* executor) { VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, versioned_handle.ToString().c_str()); @@ -447,8 +436,6 @@ StatusOr> Service::BuildExecutable( StatusOr> Service::BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile) { std::shared_ptr executable = @@ -471,8 +458,8 @@ StatusOr> Service::BuildAndCacheExecutable( HloModuleConfig original_module_config = *module_config; TF_ASSIGN_OR_RETURN( std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), arguments, - backend, executor)); + BuildExecutable(versioned_handle, std::move(module_config), backend, + executor)); if (profile != nullptr) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -489,9 +476,7 @@ StatusOr> Service::BuildAndCacheExecutable( StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice< - std::vector> - arguments, + tensorflow::gtl::ArraySlice> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile) { @@ -547,7 +532,7 @@ Service::ExecuteParallelAndRegisterResult( // Asynchronously launch the computation. TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase result, + std::unique_ptr result, executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); if (replica == 0 && profile != nullptr) { @@ -557,17 +542,20 @@ Service::ExecuteParallelAndRegisterResult( // All replicas share the same device address for the result allocation, // so only one of the replicas need to register the result handle. if (replica == 0) { - result_handles.push_back(allocation_tracker_.Register( - backend, replicas[0]->device_ordinal(), result, - executables[i]->result_shape(), result_tags[i])); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle handle, + allocation_tracker_.Register(std::move(result), result_tags[i])); + result_handles.push_back(handle); } } } // Wait for all executions to complete. for (int64 i = 0; i < streams.size(); ++i) { - if (!streams[i]->BlockHostUntilDone()) { - return InternalError("failed to complete execution for stream %lld", i); + Status block_status = streams[i]->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("failed to complete execution for stream %lld: %s", + i, block_status.error_message().c_str()); } } @@ -625,8 +613,7 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile) { // Set up streams. @@ -651,6 +638,7 @@ StatusOr Service::ExecuteAndRegisterResult( for (const Pool::SmartPtr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); + options.set_device_ordinal(stream->parent()->device_ordinal()); options.set_allocator(backend->memory_allocator()); options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( @@ -660,24 +648,21 @@ StatusOr Service::ExecuteAndRegisterResult( backend->inter_op_thread_pool()); } - perftools::gputools::DeviceMemoryBase result; + std::unique_ptr result; if (options_.number_of_replicas() == 1) { - TF_ASSIGN_OR_RETURN( - result, executable->ExecuteOnStreamWrapper( - &run_options[0], profile, arguments)); + TF_ASSIGN_OR_RETURN(result, executable->ExecuteOnStreamWrapper( + &run_options[0], profile, arguments)); } else { - std::vector< - tensorflow::gtl::ArraySlice> + // TODO(b/69985541): Support profiling also on this path. + std::vector> repeated_arguments(options_.number_of_replicas(), arguments); TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( run_options, repeated_arguments)); TF_RET_CHECK(!results.empty()); - result = results[0]; + result = std::move(results[0]); } - return allocation_tracker_.Register(backend, executor->device_ordinal(), - result, executable->result_shape(), - result_tag); + return allocation_tracker_.Register(std::move(result), result_tag); } tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, @@ -691,7 +676,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - std::vector> all_arguments; + std::vector> all_arguments; std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; @@ -748,19 +733,14 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // In the case of partitioned computations, assume all arguments go on the // zeroth core. TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(request.arguments(), executors[0]->device_ordinal())); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } // Create an HloModuleConfig object for the computation, given the shape of // the program and the argument allocations. TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, + CreateModuleConfig(*program_shape, arguments, request.execution_options())); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -863,35 +843,30 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, user_computation->ComputeProgramShape(versioned_handle.version)); TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(arg->arguments(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, arg->execution_options())); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } - TF_ASSIGN_OR_RETURN( std::shared_ptr executable, BuildAndCacheExecutable(versioned_handle, std::move(module_config), - arguments, execute_backend_.get(), + execute_backend_.get(), execute_backend_->default_stream_executor(), result->mutable_profile())); if (executable->dumping()) { executable->session_module()->set_execution_platform( execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR( - RecordArguments(arg_allocations, executable->session_module())); + TF_RETURN_IF_ERROR(RecordArguments( + arguments, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->session_module())); } TF_ASSIGN_OR_RETURN( @@ -902,10 +877,11 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, "result of " + user_computation->name(), result->mutable_profile())); if (executable->dumping()) { - TF_ASSIGN_OR_RETURN(const Allocation* result_allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, allocation_tracker_.Resolve(result->output())); - TF_RETURN_IF_ERROR( - RecordResult(result_allocation, executable->session_module())); + TF_RETURN_IF_ERROR(RecordResult( + *result_buffer, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->session_module())); TF_RETURN_IF_ERROR(executable->DumpSessionModule()); } @@ -931,31 +907,24 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, user_computation->ComputeProgramShape(versioned_handle.version)); TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(arg->arguments(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, arg->execution_options())); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } - ExecutionProfile profile; TF_ASSIGN_OR_RETURN( std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - arguments, execute_backend_.get(), - execute_backend_->default_stream_executor(), - &profile)); + BuildAndCacheExecutable( + versioned_handle, std::move(module_config), execute_backend_.get(), + execute_backend_->default_stream_executor(), &profile)); TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); @@ -970,7 +939,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, streams.push_back(std::move(stream)); } - perftools::gputools::DeviceMemoryBase result_data; + std::unique_ptr result_buffer; for (const Pool::SmartPtr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); @@ -983,19 +952,19 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, options, execute_backend_->StreamBorrower()); TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase this_result_data, + std::unique_ptr this_result_buffer, executable->ExecuteAsyncOnStream(&service_options, arguments)); // Take the first result. - if (result_data == nullptr) { - result_data = this_result_data; + if (result_buffer == nullptr) { + result_buffer = std::move(this_result_buffer); } } - auto output = allocation_tracker_.Register( - execute_backend_.get(), execute_backend_->default_device_ordinal(), - result_data, executable->result_shape(), - "result of " + user_computation->name()); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle output, + allocation_tracker_.Register(std::move(result_buffer), + "result of " + user_computation->name())); *result->mutable_execution() = execution_tracker_.Register( execute_backend_.get(), std::move(streams), profile, output); @@ -1022,37 +991,58 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, TransferToClientResponse* result) { - TF_ASSIGN_OR_RETURN(const Allocation* allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.Resolve(arg->data())); - const Shape* literal_shape; + const Shape* return_shape; if (arg->has_shape_with_layout()) { if (!LayoutUtil::HasLayout(arg->shape_with_layout())) { return InvalidArgument("shape_with_layout must have layout if present."); } - literal_shape = &arg->shape_with_layout(); + return_shape = &arg->shape_with_layout(); } else { - literal_shape = &allocation->shape(); + return_shape = &shaped_buffer->on_host_shape(); } - Literal literal; - TF_RETURN_IF_ERROR( - LiteralFromAllocation(allocation, *literal_shape, &literal)); - *result->mutable_literal() = literal.ToProto(); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(shaped_buffer->device_ordinal())); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_literal, + execute_backend_->transfer_manager()->TransferLiteralFromDevice( + executor, *shaped_buffer)); + + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, + result_literal->shape())) { + *result->mutable_literal() = result_literal->ToProto(); + } else { + *result->mutable_literal() = + result_literal->Relayout(*return_shape)->ToProto(); + } return tensorflow::Status::OK(); } +namespace { + +// Creates a clone of the given shaped buffer with the given device ordinal. The +// shape and DeviceMemoryBase values of the clone are identical to the original. +std::unique_ptr CloneShapedBufferOnDevice( + const ShapedBuffer& shaped_buffer, int device_ordinal) { + auto clone = MakeUnique( + shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), + shaped_buffer.platform(), device_ordinal); + clone->buffers() = shaped_buffer.buffers(); + return clone; +} + +} // namespace + tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { - Literal literal = Literal(arg->literal()); - const Shape& shape = literal.shape(); - - if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) { - // TODO(b/32990684): Tuple transfers to host end up allocating further - // buffers - implement that correctly. - return Unimplemented( - "Tuple transfers to the device not supported with replication."); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + Literal::CreateFromProto(arg->literal())); + const Shape& shape = literal->shape(); std::vector replicas; if (arg->has_device_handle()) { @@ -1063,25 +1053,38 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } - // Allocate memory on the device, using the stream executor. The size of the - // allocation is obtained by examining the shape of the literal passed from - // the client. An allocation handle is returned in the response. - int64 allocation_size = - execute_backend_->transfer_manager()->GetByteSizeRequirement(shape); - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, - execute_backend_->memory_allocator()->Allocate( - replicas[0]->device_ordinal(), allocation_size)); - - *result->mutable_data() = allocation_tracker_.Register( - execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape, - StrCat("TransferToServer literal of size ", allocation_size)); + // All memory allocation is done on the first replica. The allocations in all + // other replicas mirror the firsts'. + int master_device_ordinal = replicas[0]->device_ordinal(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr shaped_buffer, + execute_backend_->transfer_manager()->AllocateShapedBuffer( + shape, execute_backend_->memory_allocator(), master_device_ordinal)); + // Transfer the data to the replicas. for (se::StreamExecutor* executor : replicas) { - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, literal, &allocation)); + if (executor->device_ordinal() == master_device_ordinal) { + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, *literal, *shaped_buffer)); + } else { + // The replica is not the master. Create an cloned shaped buffer with + // the replica's device ordinal. This is required because + // TransferLiteralToDevice verifies that the device ordinal of the shaped + // buffer matches that of the executor. + std::unique_ptr clone = + CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal()); + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, *literal, *clone)); + } } + TF_ASSIGN_OR_RETURN( + *result->mutable_data(), + allocation_tracker_.Register(std::move(shaped_buffer), + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape)))); + return tensorflow::Status::OK(); } @@ -1109,8 +1112,10 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor = replicas[arg->replica_id()]; } + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + Literal::CreateFromProto(arg->literal())); return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, Literal(arg->literal())); + executor, *literal); } tensorflow::Status Service::TransferFromOutfeed( @@ -1185,7 +1190,22 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, bool is_constant, user_computation->IsConstant(arg->operand(), arg->parameters_size())); if (!is_constant) { - return InvalidArgument("Operand to ComputeConstant depends on parameter."); + StatusOr op_request_status = + user_computation->LookUpRequestForErrorReporting(arg->operand()); + string op_request_string = ""; + if (op_request_status.ok()) { + op_request_string = op_request_status.ValueOrDie()->ShortDebugString(); + } + return InvalidArgument( + "Operand to ComputeConstant depends on a parameter.\n\n" + " op requested for constant evaluation: %s\n\n" + "This is an internal error that typically happens when the XLA user " + "(e.g. TensorFlow) is attempting to determine a value that must be a " + "compile-time constant (e.g. an array dimension) but it is not capable " + "of being evaluated at XLA compile time.\n\n" + "Please file a usability bug with the framework being used (e.g. " + "TensorFlow).", + op_request_string.c_str()); } // We can't use ComputeProgramShape because it checks that all parameter @@ -1222,18 +1242,16 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, /*include_unreachable_instructions=*/ false)); - std::vector parameters(arg->parameters_size()); + std::vector> parameters(arg->parameters_size()); for (int64 i = 0; i < arg->parameters_size(); ++i) { - parameters[i] = Literal(arg->parameters(i)); + TF_ASSIGN_OR_RETURN(parameters[i], + Literal::CreateFromProto(arg->parameters(i))); } - std::vector parameter_ptrs; - std::transform(parameters.begin(), parameters.end(), - std::back_inserter(parameter_ptrs), - [](const Literal& literal) { return &literal; }); - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, - evaluator.Evaluate(*module, parameter_ptrs)); + TF_ASSIGN_OR_RETURN( + auto result_literal, + evaluator.Evaluate>(*module, parameters)); + // Since the shape_with_output_layout option in ExecutionOption is // non-effective to the Evaluator results, explicit relayout here. if (arg->has_output_layout()) { @@ -1246,9 +1264,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, tensorflow::Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { - TF_ASSIGN_OR_RETURN(const Allocation* allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.Resolve(arg->data())); - *result->mutable_shape() = allocation->shape(); + *result->mutable_shape() = buffer->on_host_shape(); return tensorflow::Status::OK(); } @@ -1357,6 +1375,17 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddConcatenateInstruction(arg->concatenate_request()); break; + case OpRequest::kConditionalRequest: { + TF_ASSIGN_OR_RETURN(UserComputation * true_computation, + computation_tracker_.Resolve( + arg->conditional_request().true_computation())); + TF_ASSIGN_OR_RETURN(UserComputation * false_computation, + computation_tracker_.Resolve( + arg->conditional_request().false_computation())); + handle_status = computation->AddConditionalInstruction( + arg->conditional_request(), *true_computation, *false_computation); + break; + } case OpRequest::kConstantRequest: handle_status = computation->AddConstantInstruction(arg->constant_request()); @@ -1381,6 +1410,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddCustomCallInstruction(arg->custom_call_request()); break; + case OpRequest::kDotRequest: + handle_status = computation->AddDotInstruction(arg->dot_request()); + break; case OpRequest::kDynamicSliceRequest: handle_status = computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); @@ -1389,6 +1421,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddDynamicUpdateSliceInstruction( arg->dynamic_update_slice_request()); break; + case OpRequest::kFftRequest: + handle_status = computation->AddFftInstruction(arg->fft_request()); + break; case OpRequest::kGetTupleElementRequest: handle_status = computation->AddGetTupleElementInstruction( arg->get_tuple_element_request()); @@ -1501,8 +1536,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddRecvInstruction(arg->recv_request()); break; } + case OpRequest::OP_NOT_SET: + return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); default: - return InvalidArgument("Unsupported operation"); + return InvalidArgument("Unsupported operation in XLA service"); } TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 47f4f0ade594089aa71717ef1e122886b0a6c7ac..f962d0cdc7d41e1aeab55da5abcb1b40215b4144 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -250,7 +250,7 @@ class Service : public ServiceInterface { // class. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, const ExecutionOptions& execution_options); protected: @@ -265,10 +265,10 @@ class Service : public ServiceInterface { // Resolves the given argument handles in the allocation tracker and returns // the corresponding allocations. The function also verifies that each - // allocation matches the given backend and device ordinal. - StatusOr> ResolveAndValidateArguments( + // allocation matches the execution platform and device ordinal. + StatusOr> ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - const Backend* backend, int device_ordinal); + int device_ordinal); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. @@ -281,8 +281,6 @@ class Service : public ServiceInterface { StatusOr> BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, perftools::gputools::StreamExecutor* executor); // Same as BuildExecutable() above, but builds a list of Executables for the @@ -299,8 +297,6 @@ class Service : public ServiceInterface { StatusOr> BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile); @@ -310,8 +306,7 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile); @@ -320,9 +315,7 @@ class Service : public ServiceInterface { // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice< - std::vector> - arguments, + tensorflow::gtl::ArraySlice> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 017e5ef09ed2f52b862821e9408540d188a1edf5..6c1f8feac7ed4423051cf2737be57dcfab508671 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -30,6 +30,9 @@ class ServiceExecutableRunOptions { using StreamBorrower = std::function::SmartPtr>(int)>; + ServiceExecutableRunOptions() + : ServiceExecutableRunOptions(ExecutableRunOptions()) {} + explicit ServiceExecutableRunOptions( ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 3df1911d07cf0cd123604b1fac63923a725a37c6..a6d6c8b27f81045a4bee09e056c5c8f8e8a330c7 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -90,8 +91,6 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_ATAN2; case HloOpcode::kComplex: return BINOP_COMPLEX; - case HloOpcode::kDot: - return BINOP_DOT; case HloOpcode::kMultiply: return BINOP_MUL; case HloOpcode::kAdd: @@ -549,8 +548,113 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); } -/* static */ StatusOr ShapeInference::InferDotOpShape(const Shape& lhs, - const Shape& rhs) { +// Current DotDimensionNumbers Requirements: +// +// Contracting Dimensions: +// *) Exactly one contracting dimension on both lhs and rhs. +// *) Contracting dimension size must be the same on both lhs and rhs. +// *) Contracting dimension numbers do not need to be the same (i.e. transposes +// are passed on to emitter implementations). +// +// Batch Dimensions: +// *) Same number of batch dimensions on both lhs and rhs. +// *) Same batch dimension numbers (and sizes) on both lhs and rhs. +// *) Batch dimension numbers must be ordered before contracting and +// non-contracting/non-batch dimension numbers. +// +// Non-Contracting-Non-Batch Dimensions: +// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// + +namespace { + +Status ValidateDotDimensionNumbers( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { + // Check that dimension numbers are in range. + auto dims_in_range = + [](const int64 rank, tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + in_range) && + std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + }; + + tensorflow::gtl::ArraySlice lhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice rhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice lhs_batch_dimensions = + AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); + tensorflow::gtl::ArraySlice rhs_batch_dimensions = + AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); + + if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + lhs_batch_dimensions) || + !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is out of range in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that dimension numbers are unique. + auto dims_unique = [](tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + tensorflow::gtl::FlatSet dim_set; + auto is_unique = [&dim_set](int64 i) -> bool { + return dim_set.insert(i).second; + }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + is_unique) && + std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + }; + + if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || + !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is not unique in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. + const int64 lhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(lhs) - + dimension_numbers.lhs_contracting_dimensions_size() - + dimension_numbers.lhs_batch_dimensions_size(); + const int64 rhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(rhs) - + dimension_numbers.rhs_contracting_dimensions_size() - + dimension_numbers.rhs_batch_dimensions_size(); + if (lhs_non_contracting_non_batch_dims < 0 || + lhs_non_contracting_non_batch_dims > 1 || + rhs_non_contracting_non_batch_dims < 0 || + rhs_non_contracting_non_batch_dims > 1) { + return InvalidArgument( + "batch and contracting dimension number mismatch " + "with rank "); + } + + // Check that batch dimension numbers are ordered before all others, and + // that they are monotonically increasing. + std::vector batch_dim_numbers(lhs_batch_dimensions.size()); + std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0); + if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + lhs_batch_dimensions.begin()) || + !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + rhs_batch_dimensions.begin())) { + return InvalidArgument( + "batch dimension numbers must precede non-batch dimensions and be" + "monotonically increasing."); + } + + return Status::OK(); +} + +} // namespace + +/* static */ StatusOr ShapeInference::InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); @@ -570,37 +674,62 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return fail("element types do not match"); } - if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 || - ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) { - return fail("dot only supports rank 1 or 2"); + if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + return fail("dot only supports rank 1 or above."); } - // Determine the index of the contracted dimensions for input tensors. - // dimensions -1 of lhs and dimension 0 of rhs are contracted. - int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1); - int64 rhs_contracted_dimension = 0; + // Validate basic properties of dot dimension numbers. + TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); + + // Check that there is only one contracting dimension for both lhs and rhs. + if (dimension_numbers.lhs_contracting_dimensions_size() != + dimension_numbers.rhs_contracting_dimensions_size() || + dimension_numbers.lhs_contracting_dimensions_size() != 1) { + return fail("must specify one contracting dimension for both lhs and rhs."); + } - // Check if the contracted dimension sizes are the same. - if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) && - rhs_contracted_dimension < ShapeUtil::Rank(rhs)) && - lhs.dimensions(lhs_contracted_dimension) != - rhs.dimensions(rhs_contracted_dimension)) { - return fail("contracted dimensions mismatch"); + // Check that contracting dimension sizes match. + const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(0); + if (lhs.dimensions(lhs_contracting_dimension) != + rhs.dimensions(rhs_contracting_dimension)) { + return fail("contracting dimension sizes do not match."); + } + + // Check that number of batch dimensions match. + if (dimension_numbers.lhs_batch_dimensions_size() != + dimension_numbers.rhs_batch_dimensions_size()) { + return fail("must the same number of batch dimensions for lhs and rhs."); + } + + // Check that batch dimension numbers and sizes match. + for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { + if (dimension_numbers.lhs_batch_dimensions(i) != + dimension_numbers.rhs_batch_dimensions(i) || + lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { + return fail("batch dimension numbers and sizes must match for lhs/rhs."); + } } // The ranks of lhs and rhs are decremented by 1 respectively due to the // contraction, and added for the rank of the result. When an input tensor is // a scalar, its contribution to the rank of the result is 0. // Generate the result dimensions in order, rhs dimensions followed by lhs - // dimensions except the contracted dimensions. + // dimensions except the contracted and batch dimensions. std::vector dimensions; + std::unordered_set rhs_batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { - if (i != lhs_contracted_dimension) { + if (i != lhs_contracting_dimension) { dimensions.push_back(lhs.dimensions(i)); } } for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { - if (i != rhs_contracted_dimension) { + if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { dimensions.push_back(rhs.dimensions(i)); } } @@ -816,8 +945,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( rhs, tensorflow::strings::StrCat("rhs of binary operation ", BinaryOperation_Name(operation)))); switch (operation) { - case BINOP_DOT: - return InferDotOpShape(lhs, rhs); case BINOP_MAX: case BINOP_MIN: case BINOP_SUB: @@ -1588,11 +1715,103 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return ShapeUtil::MakeShape(lhs.element_type(), dimensions); } +/* static */ StatusOr ShapeInference::InferFftShape( + const Shape& in, const FftType fft_type, + const tensorflow::gtl::ArraySlice fft_length) { + const int64 fft_rank = fft_length.size(); + if (fft_rank < 1 || fft_rank > 3) { + return InvalidArgument("FFT only supports ranks 1-3, but got %lld", + fft_rank); + } +#define RET_CHECK_RANK(x) \ + if (x.dimensions_size() < fft_rank) { \ + return InvalidArgument( \ + "FFT of rank %lld requires input of at least " \ + "same rank; got input of rank %d", \ + fft_rank, x.dimensions_size()); \ + } + switch (fft_type) { + case FFT: + case IFFT: + if (in.element_type() != C64) { + return InvalidArgument("%s requires C64 input type, found %s", + FftType_Name(fft_type).c_str(), + PrimitiveType_Name(in.element_type()).c_str()); + } + RET_CHECK_RANK(in); + return in; + case RFFT: { + if (in.element_type() != F32) { + return InvalidArgument("RFFT requires F32 input type, found %s", + PrimitiveType_Name(in.element_type()).c_str()); + } + RET_CHECK_RANK(in); + for (int i = 0; i < fft_rank; i++) { + if (in.dimensions(in.dimensions_size() - fft_rank + i) != + fft_length[i]) { + return InvalidArgument( + "RFFT requires innermost dimensions match fft_length but " + "dimension %lld is %lld and should be %lld", + in.dimensions_size() - fft_rank + i, + in.dimensions(in.dimensions_size() - fft_rank + i), + fft_length[i]); + } + } + Shape result = ShapeUtil::ChangeElementType(in, C64); + result.set_dimensions(result.dimensions_size() - 1, + fft_length[fft_rank - 1] / 2 + 1); + return result; + } + case IRFFT: { + if (in.element_type() != C64) { + return InvalidArgument("IRFFT requires C64 input type, found %s", + PrimitiveType_Name(in.element_type()).c_str()); + } + RET_CHECK_RANK(in); + Shape result = ShapeUtil::ComplexComponentShape(in); + for (int i = 0; i < fft_rank - 1; i++) { + if (in.dimensions(in.dimensions_size() - fft_rank + i) != + fft_length[i]) { + return InvalidArgument( + "IRFFT requires all but one innermost dimensions match " + "fft_length, but dimension %lld is %lld and should be %lld", + in.dimensions_size() - fft_rank + i, + in.dimensions(in.dimensions_size() - fft_rank + i), + fft_length[i]); + } + } + if (in.dimensions(in.dimensions_size() - 1) != + fft_length[fft_rank - 1] / 2 + 1) { + return InvalidArgument( + "IRFFT requires innermost dimension matches fft_length/2+1, but " + "dimension %d is %lld and should be %lld", + in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), + fft_length[fft_rank - 1] / 2 + 1); + } + result.set_dimensions(result.dimensions_size() - 1, + fft_length[fft_rank - 1]); + return result; + } + default: + LOG(FATAL) << "Unexpected fft_type: " << fft_type; + } +#undef RET_CHECK_RANK +} + /* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( - const Shape& operand) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand, "operand of cross replica sum")); - return operand; + tensorflow::gtl::ArraySlice operand_shapes) { + for (const Shape* operand_shape : operand_shapes) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum")); + } + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } + std::vector operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); + } + return ShapeUtil::MakeTupleShape(operand_shape_values); } /* static */ StatusOr ShapeInference::InferReduceShape( @@ -1958,6 +2177,64 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return init; } +/* static */ StatusOr ShapeInference::InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation) { + if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { + return InvalidArgument("predicate must be a boolean; got %s.", + ShapeUtil::HumanString(predicate).c_str()); + } + + if (true_computation.parameters_size() != 1) { + return InvalidArgument("true_computation must take 1 argument; got %d.", + true_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { + auto true_shape_string = [&]() { + return tensorflow::strings::Printf( + "true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand).c_str(), + ShapeUtil::HumanString(true_computation).c_str()); + }; + return InvalidArgument( + "true_operand must match the shape of the only parameter of " + "true_computation: got %s.", + true_shape_string().c_str()); + } + + if (false_computation.parameters_size() != 1) { + return InvalidArgument("false_computation must take 1 argument; got %d.", + false_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { + auto false_shape_string = [&]() { + return tensorflow::strings::Printf( + "false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand).c_str(), + ShapeUtil::HumanString(false_computation).c_str()); + }; + return InvalidArgument( + "false_operand must match the shape of the only parameter of " + "false_computation: got %s.", + false_shape_string().c_str()); + } + if (!ShapeUtil::Compatible(true_computation.result(), + false_computation.result())) { + auto shape_string = [&]() { + return tensorflow::strings::Printf( + "true_computation result: %s; false_computation result: %s.", + ShapeUtil::HumanString(true_computation.result()).c_str(), + ShapeUtil::HumanString(false_computation.result()).c_str()); + }; + return InvalidArgument( + "the result of true_computation and false_computation must have the " + "same shape: got %s.", + shape_string().c_str()); + } + return true_computation.result(); +} + /* static */ StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 0aadb98a407c2160b60e686f6f3ea250bb9e838f..b39151ebbc19f5d0b702a80da5069f58c8dfb07d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,8 +109,15 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); - // Infers the shape produced a cross replica sum with the given operand shape. - static StatusOr InferCrossReplicaSumShape(const Shape& operand); + // Infers the shape produced by the given FFT type on the given operand. + static StatusOr InferFftShape( + const Shape& in, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + + // Infers the shape produced a cross replica sum with the given operand + // shapes. + static StatusOr InferCrossReplicaSumShape( + tensorflow::gtl::ArraySlice operand_shapes); // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. @@ -178,6 +185,12 @@ class ShapeInference { const ProgramShape& body, const Shape& init); + // Infers the shape produced by a conditional operation. + static StatusOr InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation); + // Infers the shape produced by a broadcast operation. static StatusOr InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); @@ -229,11 +242,13 @@ class ShapeInference { tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply); - private: // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. - static StatusOr InferDotOpShape(const Shape& lhs, const Shape& rhs); + static StatusOr InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. // Note: By "element-wise" we mean operations that look at a single element in diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index be93c879c0b7fd74c3b93e28c6dc0f5c656a522a..99d87f3b550ae72befe254f23fad080dd210aaf4 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -898,8 +898,11 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar vector: error TEST_F(ShapeInferenceTest, ScalarDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); + ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("dot only supports rank")); @@ -907,61 +910,199 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { // 3D 2D: error TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("batch and contracting dimension number mismatch")); } // vector vector -> scalar TEST_F(ShapeInferenceTest, VectorDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); auto inferred_status_mismatch = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix vector -> vector TEST_F(ShapeInferenceTest, MatrixDotVector) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // vector matrix -> vector TEST_F(ShapeInferenceTest, VectorDotMatrix) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix matrix -> matrix TEST_F(ShapeInferenceTest, MatrixDotMatrix) { - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status_match = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) << "inferred: " << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } +// BatchMatMul with two batch dimensions and one contracting dimension. +TEST_F(ShapeInferenceTest, DotGeneral) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status_match = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE( + ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) + << "inferred: " + << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) + << " expected: " << ShapeUtil::HumanString(output_shape); +} + +// BatchMatMul with two contracting dimensions fails. +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("must specify one contracting dimension for both " + "lhs and rhs")); +} + +// BatchMatMul with different batch dimension sizes fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers and sizes must match")); +} + +// BatchMatMul with different batch dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers must precede non-batch")); +} + +// BatchMatMul with out-of-range dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is out of range")); +} + +// BatchMatMul with non-unique dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is not unique")); +} + TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. @@ -1296,5 +1437,80 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } +TEST_F(ShapeInferenceTest, Conditional) { + auto inferred_status0 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_IS_OK(inferred_status0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); + + auto inferred_status1 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, vector_32_, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), + ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)); + EXPECT_IS_OK(inferred_status1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); + + auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); + auto inferred_status2 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, tuple_f32_v32, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)); + EXPECT_IS_OK(inferred_status2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); + + auto inferred_status_error0 = ShapeInference::InferConditionalShape( + s32_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error0.ok()); + EXPECT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("predicate must be a boolean")); + + auto inferred_status_error1 = ShapeInference::InferConditionalShape( + pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)); + EXPECT_FALSE(inferred_status_error1.ok()); + EXPECT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("true_computation must take 1 argument")); + + auto inferred_status_error2 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_64_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error2.ok()); + EXPECT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("true_operand must match the shape of the only " + "parameter of true_computation")); + + auto inferred_status_error3 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)); + EXPECT_FALSE(inferred_status_error3.ok()); + EXPECT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("false_computation must take 1 argument")); + + auto inferred_status_error4 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + EXPECT_FALSE(inferred_status_error4.ok()); + EXPECT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("false_operand must match the shape of the only " + "parameter of false_computation")); + + auto inferred_status_error5 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)); + EXPECT_FALSE(inferred_status_error5.ok()); + EXPECT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("the result of true_computation and false_computation " + "must have the same shape")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index a7539a1a11d2bbd62c780890c6730dbb212307c4..c679d401c3691b14a43ce77cbe953cd4c64a9e92 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -34,58 +34,32 @@ namespace xla { using ::tensorflow::strings::Appendf; -/* static */ StatusOr> -ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, - const se::Platform* platform, - int device_ordinal, - const se::DeviceMemoryBase& buffer) { - if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Shape must be an array: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - auto shaped_buffer = - MakeUnique(shape, platform, device_ordinal); - *shaped_buffer->mutable_shape_index_to_buffer_entry()->mutable_element({}) = - 0; - *shaped_buffer->mutable_buffers() = {buffer}; - return std::move(shaped_buffer); -} - -ShapedBuffer::ShapedBuffer(const Shape& shape, const se::Platform* platform, - int device_ordinal) - : shape_(shape), +ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, + const se::Platform* platform, int device_ordinal) + : on_host_shape_(on_host_shape), + on_device_shape_(on_device_shape), platform_(platform), device_ordinal_(device_ordinal), - shape_index_to_buffer_entry_(shape) {} + buffers_(on_device_shape) {} void ShapedBuffer::clear() { - for (se::DeviceMemoryBase& memory_base : buffers_) { + for (auto& pair : buffers_) { // A default constructed DeviceMemoryBase is a null pointer. - memory_base = se::DeviceMemoryBase(); + pair.second = se::DeviceMemoryBase(); } } -void ShapedBuffer::AddBufferAtIndex( - const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& shape_index) { - *mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) = - buffers().size(); - mutable_buffers()->push_back(buffer); -} - -const se::DeviceMemoryBase& ShapedBuffer::buffer( - const ShapeIndex& index) const { - return buffers_[shape_index_to_buffer_entry_.element(index)]; -} - -se::DeviceMemoryBase* ShapedBuffer::mutable_buffer(const ShapeIndex& index) { - return &buffers_[shape_index_to_buffer_entry_.element(index)]; -} - string ShapedBuffer::ToString() const { - string s = "ShapedBuffer(" + platform_->Name() + "):\n"; + string s = tensorflow::strings::StrCat( + "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), + "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), + ", on-device shape=" + + ShapeUtil::HumanStringWithLayout(on_device_shape()), + ":\n"); ShapeUtil::ForEachSubshape( - shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { + on_device_shape(), + [this, &s](const Shape& subshape, const ShapeIndex& index) { string shape_str; if (ShapeUtil::IsTuple(subshape)) { shape_str = "tuple"; @@ -105,53 +79,24 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { return out; } -/* static */ StatusOr> -ScopedShapedBuffer::Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn) { - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - auto shaped_buffer = - WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal)); - - // Allocate an appropriate sized buffer for each element in the shape - // including the tuple pointer arrays. - for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { - const ShapeIndex& index = pair.first; - size_t& buffer_entry = pair.second; - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase memory_base, - shaped_buffer->allocator_->Allocate( - shaped_buffer->device_ordinal(), - shape_size_fn(ShapeUtil::GetSubshape( - shaped_buffer->shape(), index)))); - shaped_buffer->buffers_.push_back(memory_base); - buffer_entry = shaped_buffer->buffers_.size() - 1; - } - - return std::move(shaped_buffer); -} - /* static */ StatusOr> ScopedShapedBuffer::MakeScoped( ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator) { auto scoped_buffer = WrapUnique(new ScopedShapedBuffer( - shaped_buffer->shape(), allocator, shaped_buffer->device_ordinal())); + shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(), + allocator, shaped_buffer->device_ordinal())); scoped_buffer->buffers_ = shaped_buffer->buffers(); - scoped_buffer->shape_index_to_buffer_entry_ = - shaped_buffer->shape_index_to_buffer_entry(); - shaped_buffer->clear(); return std::move(scoped_buffer); } -ScopedShapedBuffer::ScopedShapedBuffer(const Shape& shape, +ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, DeviceMemoryAllocator* allocator, int device_ordinal) - : ShapedBuffer(shape, allocator->platform(), device_ordinal), + : ShapedBuffer(on_host_shape, on_device_shape, allocator->platform(), + device_ordinal), allocator_(allocator) {} ScopedShapedBuffer::~ScopedShapedBuffer() { @@ -159,7 +104,8 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. std::set deallocated_opaques; - for (se::DeviceMemoryBase& memory_base : buffers_) { + for (auto& pair : buffers_) { + se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && deallocated_opaques.count(memory_base.opaque()) == 0) { deallocated_opaques.insert(memory_base.opaque()); @@ -170,13 +116,10 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { } std::unique_ptr ScopedShapedBuffer::release() { - auto shaped_buffer = - MakeUnique(shape(), platform(), device_ordinal()); - - *shaped_buffer->mutable_buffers() = buffers(); - *shaped_buffer->mutable_shape_index_to_buffer_entry() = - shape_index_to_buffer_entry(); + auto shaped_buffer = MakeUnique( + on_host_shape(), on_device_shape(), platform(), device_ordinal()); + shaped_buffer->buffers() = buffers(); clear(); return shaped_buffer; diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index fa88caa13ff734995e8ab0925f17d0d3c26b8fda..d397e47d2ca734458c7dc99baa5c81b16d0fd72b 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -31,61 +31,68 @@ limitations under the License. namespace xla { // Class which encapsulates a buffer or set of buffers containing data of a -// particular XLA shape. Used for zero-copy execution interface for a -// XLA client running in the same process as the service (LocalClient), +// particular XLA shape. class ShapedBuffer { public: - // Convenience method which creates a ShapedBuffer of array shape (not a - // tuple). Its single buffer pointer is set to the given value "buffer". The - // given buffer must be large enough to store the given shape as given by - // ShapeUtil::ByteSizeOf. - static StatusOr> MakeArrayShapedBuffer( - const Shape& shape, const perftools::gputools::Platform* platform, - int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer); - - ShapedBuffer(const Shape& shape, + // Construct a ShapedBuffer with null DeviceMemoryBases at each index. The + // shape of the data on the host and the device may differ because the device + // may have a different representation for different data types. Therefore, + // both the on-host and on-device shape are required. The on-device shape + // determines the number of device allocations (DeviceMemoryBase) held by the + // ShapedBuffer. + ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, const perftools::gputools::Platform* platform, int device_ordinal); - const Shape& shape() const { return shape_; } + // Returns the shape of the on-host representation of the data held by this + // ShapedBuffer. + const Shape& on_host_shape() const { return on_host_shape_; } + + // Returns the shape of the on-device representation of the data held by this + // ShapedBuffer. + const Shape& on_device_shape() const { return on_device_shape_; } + const perftools::gputools::Platform* platform() const { return platform_; } int device_ordinal() const { return device_ordinal_; } + // Return the root buffer of the shape (shape index {}). + const perftools::gputools::DeviceMemoryBase& root_buffer() const { + return buffer(/*index=*/{}); + } + // Returns the buffer at the given shape index where index is defined as in // ShapeUtil::GetSubshape. const perftools::gputools::DeviceMemoryBase& buffer( - const ShapeIndex& index) const; - perftools::gputools::DeviceMemoryBase* mutable_buffer( - const ShapeIndex& index); - - // Returns the underlying structure which stores the buffer pointers. - const std::vector& buffers() const { - return buffers_; + const ShapeIndex& index) const { + return buffers_.element(index); } - std::vector* mutable_buffers() { - return &buffers_; + + // Sets the device memory buffer at the given index. + void set_buffer(const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& index) { + *buffers_.mutable_element(index) = buffer; } - // Returns the tree of indices which map to buffer pointers. - const ShapeTree& shape_index_to_buffer_entry() const { - return shape_index_to_buffer_entry_; + // Returns the underlying ShapeTree containing all the device addresses in the + // ShapedBuffer. + const ShapeTree& buffers() const { + return buffers_; } - ShapeTree* mutable_shape_index_to_buffer_entry() { - return &shape_index_to_buffer_entry_; + ShapeTree& buffers() { + return buffers_; } // Set all device memory pointers in the object to null. void clear(); - // Adds a new buffer at the given shape index. - void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& shape_index); - string ToString() const; protected: - // The shape of the device buffer with layout. - const Shape shape_; + // The shape of the data when represented on the host. + const Shape on_host_shape_; + + // The shape of the data on the device. + const Shape on_device_shape_; // The platform the memory is allocated on. const perftools::gputools::Platform* platform_; @@ -93,14 +100,8 @@ class ShapedBuffer { // The device the memory is allocated on. const int device_ordinal_; - // The list of DeviceMemoryBase pointers representing this shape. - // Note that there can be a many to one relationship between tuple elements - // and buffers. To account for this, shape_index_to_buffer_entry_ allows us - // to make from a position in a shape to an index into this list. - std::vector buffers_; - - // The tree of indices into buffers_. - ShapeTree shape_index_to_buffer_entry_; + // The tree of device buffers. Its shape is on_device_shape(). + ShapeTree buffers_; }; std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); @@ -110,20 +111,16 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); // destructed. class ScopedShapedBuffer : public ShapedBuffer { public: - // Return a newly allocated ScopedShapedBuffer of an arbitrary shape. Array - // buffers (leaves in the shape) are allocated and uninitialized. Tuple - // buffers (if any) are allocated and initialized to the backend-specific - // representation of an array of pointers to the tuple elements. - static StatusOr> Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn); - // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the // deallocation of the device memory held in the shaped buffer. All device // memory pointers in the given ShapedBuffer are set to null. static StatusOr> MakeScoped( ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator); + // Create a ScopedShapedBuffer with null DeviceMemoryBases at each index. + ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, + DeviceMemoryAllocator* allocator, int device_ordinal); + // Return the allocator used to allocate the device memory held in this // ScopedShapedBuffer. DeviceMemoryAllocator* memory_allocator() const { return allocator_; } @@ -138,8 +135,6 @@ class ScopedShapedBuffer : public ShapedBuffer { virtual ~ScopedShapedBuffer(); protected: - ScopedShapedBuffer(const Shape& shape, DeviceMemoryAllocator* allocator, - int device_ordinal); ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; void operator=(const ScopedShapedBuffer&) = delete; diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index d5f53ad56fb019d0ae7c27fc28706f05614ece68..2f36e2b16e0f2eed10aef811dd3cceeba6a5b8a9 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -40,6 +40,45 @@ TransferManager::GetPlatformTransferManagers() { return r; } +Status TransferManager::TransferArrayToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const perftools::gputools::DeviceMemoryBase& dest) { + const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); + TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) + << "On-device representation of " + << ShapeUtil::HumanString(literal.shape()) + << " is not an array: " << ShapeUtil::HumanString(on_device_shape); + if (dest.size() < GetByteSizeRequirement(on_device_shape)) { + return FailedPrecondition( + "Allocation on device not large enough for array: " + "%lld < %lld", + dest.size(), GetByteSizeRequirement(on_device_shape)); + } + ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, + executor->platform(), executor->device_ordinal()); + shaped_buffer.set_buffer(dest, /*index=*/{}); + return TransferLiteralToDevice(executor, literal, shaped_buffer); +} + +StatusOr> TransferManager::TransferArrayFromDevice( + perftools::gputools::StreamExecutor* executor, const Shape& shape, + const perftools::gputools::DeviceMemoryBase& source) { + TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) + << "Shape " << ShapeUtil::HumanString(shape) + << " has a differently shaped representation on-device: " + << ShapeUtil::HumanString(HostShapeToDeviceShape(shape)); + if (source.size() < GetByteSizeRequirement(shape)) { + return FailedPrecondition( + "Allocation on device not large enough for array: " + "%lld < %lld", + source.size(), GetByteSizeRequirement(shape)); + } + ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, + executor->platform(), executor->device_ordinal()); + shaped_buffer.set_buffer(source, /*index=*/{}); + return TransferLiteralFromDevice(executor, shaped_buffer); +} + /* static */ void TransferManager::RegisterTransferManager( se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { @@ -75,14 +114,12 @@ TransferManager::GetPlatformTransferManagers() { Status TransferManager::WriteTupleIndexTables( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) { - VLOG(2) << "Writing tuple index tables to ShapedBuffer rooted at " - << device_buffer.buffer(/*index=*/{}).opaque() - << "; shape: " << ShapeUtil::HumanString(device_buffer.shape()); + VLOG(2) << "Writing tuple index tables for " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { if (ShapeUtil::IsTuple(device_subshape)) { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); @@ -97,7 +134,7 @@ Status TransferManager::WriteTupleIndexTables( elements.push_back(device_buffer.buffer(element_index)); element_index.pop_back(); } - return WriteTuplePointersToDevice(executor, elements, device_subshape, + return WriteSingleTupleIndexTable(executor, elements, device_subshape, &device_memory); } @@ -143,31 +180,43 @@ Status TransferManager::TransferBufferToDevice( return Status::OK(); } -StatusOr> -TransferManager::GatherBufferPointersFromTuple( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsTuple(shape)); - - std::set buffer_pointers; - buffer_pointers.insert(source); - - TF_ASSIGN_OR_RETURN(std::vector tuple_elements, - ShallowCopyTupleFromDevice(executor, source, shape)); - for (auto i = 0; i < tuple_elements.size(); ++i) { - const Shape& element_shape = shape.tuple_shapes(i); - if (ShapeUtil::IsTuple(element_shape)) { - TF_ASSIGN_OR_RETURN( - std::set buffer_pointers_in_element, - GatherBufferPointersFromTuple(executor, tuple_elements[i], - element_shape)); - buffer_pointers.insert(buffer_pointers_in_element.begin(), - buffer_pointers_in_element.end()); - } else { - buffer_pointers.insert(tuple_elements[i]); - } +StatusOr> TransferManager::AllocateShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal) { + if (!LayoutUtil::HasLayout(on_host_shape)) { + return InvalidArgument( + "Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); + const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); + TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); + + auto shaped_buffer = WrapUnique(new ShapedBuffer( + on_host_shape, on_device_shape, allocator->platform(), device_ordinal)); + + // Allocate an appropriate sized buffer for each element in the shape + // including the tuple pointer arrays. + for (auto& pair : shaped_buffer->buffers()) { + const ShapeIndex& index = pair.first; + se::DeviceMemoryBase& memory_base = pair.second; + const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); + TF_ASSIGN_OR_RETURN(memory_base, + allocator->Allocate(shaped_buffer->device_ordinal(), + GetByteSizeRequirement(subshape))); } - return std::move(buffer_pointers); + + return std::move(shaped_buffer); +} + +StatusOr> +TransferManager::AllocateScopedShapedBuffer(const Shape& on_host_shape, + DeviceMemoryAllocator* allocator, + int device_ordinal) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr unscoped_buffer, + AllocateShapedBuffer(on_host_shape, allocator, device_ordinal)); + return ScopedShapedBuffer::MakeScoped(unscoped_buffer.get(), allocator); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index fdc123e54eb7f754c12510bef551b98da01b585d..9f2b5c4aecf0b52f610171e0c2755de577b2bd9e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -44,55 +44,47 @@ class TransferManager { // Returns the ID of the platform that this transfer manager acts on. virtual perftools::gputools::Platform::Id PlatformId() const = 0; - // Transfers the region into the provided literal using the provided - // executor. device_shape is the shape, including layout, of the data on the - // device, while literal_shape will be the shape for the literal. device_shape - // and literal_shape must be compatible, but need not have the same layout. - // TODO(b/66694934): Remove TransferLiteral* methods which accept bare - // DeviceMemoryBase. - virtual Status TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& region, - const Shape& device_shape, const Shape& literal_shape, - Literal* literal) = 0; - - // Transfers the given literal into the provided region output parameter, - // using the given executor. - virtual Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - perftools::gputools::DeviceMemoryBase* region) = 0; - - // Transfers the data held in the given ShapedBuffer into the provided literal - // using the provided executor. literal_shape will be the shape for the - // literal. The shape of the ShapedBuffer and literal_shape must be - // compatible, but need not have the same layout. + // Returns the shape of the on-device representation for the given shape on + // the host. This is intended for use with ShapedBuffer where buffers are + // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user + // needing to consider device-specific behaviors. + virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { + return host_shape; + } + + // Returns a literal containing the data held in the given ShapedBuffer. + // using the provided executor. The optional literal_shape will be the shape + // for the literal. The shape of the ShapedBuffer and + // DeviceShape(literal_shape) must be compatible, but need not have the same + // layout. virtual StatusOr> TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; // Transfers the given literal into the previously allocated device memory - // represented by the given ShapedBuffer using the given executor. + // represented by the given ShapedBuffer using the given executor. The shape + // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, + // but need not have the same layout virtual Status TransferLiteralToDevice( perftools::gputools::StreamExecutor* executor, const Literal& literal, const ShapedBuffer& device_buffer) = 0; + // Convenience methods for transferring an array to or from the device at a + // known address. This avoids having to construct a ShapedBuffer just to + // transfer an array at a known address. + Status TransferArrayToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const perftools::gputools::DeviceMemoryBase& dest); + StatusOr> TransferArrayFromDevice( + perftools::gputools::StreamExecutor* executor, const Shape& shape, + const perftools::gputools::DeviceMemoryBase& source); + // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed( perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; - // Transfer a memory block of the given size from 'source' buffer to the - // Infeed interface of the device using the given executor. - // - // size is the size to transfer from source in bytes. - // - // source is the source data that must be in the target-dependent layout that - // the Infeed HLO used in the computation expects. - virtual Status TransferBufferToInfeed( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source) = 0; - // Transfers the given literal from the Outfeed interface of the device, // using the given executor. virtual Status TransferLiteralFromOutfeed( @@ -104,37 +96,26 @@ class TransferManager { tensorflow::gtl::ArraySlice executor) = 0; - // Shallow copy a tuple from the device and create a DeviceMemoryBase object - // for each element in the tuple. A DeviceMemoryBase object refers to the - // buffer containing the data of that element. The DeviceMemoryBase objects - // are returned as a vector. - virtual StatusOr> - ShallowCopyTupleFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& shape) = 0; - // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer); - // Returns all buffer pointers that the tuple `source` refers to. Unlike - // ShallowCopyTupleFromDevice, this function gather buffer pointers in nested - // tuples as well. Also, the returned DeviceMemoryBase objects are - // deduplicated. - StatusOr> - GatherBufferPointersFromTuple( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, const Shape& shape); - // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory // region for a host-to-device transfer. virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; - typedef std::unique_ptr (*TransferManagerCreationFunction)(); + // Allocate a ShapedBuffer which can hold data with the given on-host + // shape. The on-device shape may be different as indicated by + // HostShapeToDeviceShape. + StatusOr> AllocateShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal); + StatusOr> AllocateScopedShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal); ///// // The TransferManager class also serves as a point to register objects for @@ -144,6 +125,7 @@ class TransferManager { // assumed to be a singleton, so no ownership is transferred. // // Precondition: a platform kind must not be registered more than once. + typedef std::unique_ptr (*TransferManagerCreationFunction)(); static void RegisterTransferManager( perftools::gputools::Platform::Id platform_id, TransferManagerCreationFunction transfer_manager); @@ -154,6 +136,17 @@ class TransferManager { const perftools::gputools::Platform* platform); protected: + // Transfer a memory block of the given size from 'source' buffer to the + // Infeed interface of the device using the given executor. + // + // size is the size to transfer from source in bytes. + // + // source is the source data that must be in the target-dependent layout that + // the Infeed HLO used in the computation expects. + virtual Status TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) = 0; + // Transfer a memory block of the given size from the device source into the // 'destination' buffer. // @@ -172,10 +165,9 @@ class TransferManager { const void* source, perftools::gputools::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region - // to construct a tuple in the platform-specific tuple representation. This - // can handle nested tuples as well. In the nested case, the element - // DeviceMemoryBase points to another array of pointers on the device. - virtual Status WriteTuplePointersToDevice( + // to construct a tuple index table in the platform-specific tuple + // representation. + virtual Status WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index fb55d4e5433ce666a061256691ea08ee56fde396..83185ac49e9b7c386d10d1cbc4e20dcdfdfd6cae 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -42,7 +42,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < dot.operand_count(); ++i) { auto& operand = *dot.operand(i); - if (operand.IsRank2Transpose() && operand.user_count() == 1) { + if (operand.IsRank2Transpose()) { operand_set.push_back(i); } } @@ -61,8 +61,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < convolution.operand_count(); ++i) { auto& operand = *convolution.operand(i); - if (operand.opcode() == HloOpcode::kTranspose && - operand.user_count() == 1) { + if (operand.opcode() == HloOpcode::kTranspose) { operand_set.push_back(i); } } @@ -102,6 +101,10 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; auto& operand_indices = pair.second; + if (operand_indices.empty()) { + return false; + } + const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); ConvolutionDimensionNumbers new_dnums = dnums; @@ -121,8 +124,9 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { transpose_dimensions[dnums.input_batch_dimension()]); new_dnums.set_input_feature_dimension( transpose_dimensions[dnums.input_feature_dimension()]); - for (const auto& spatial_dimension : dnums.input_spatial_dimensions()) { - CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]); + for (auto& input_spatial_dimension : + *new_dnums.mutable_input_spatial_dimensions()) { + input_spatial_dimension = transpose_dimensions[input_spatial_dimension]; } new_lhs = &transpose_operand; } else { diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 6ac32e88f1f4af4743990daecd6c1f66a4e32763..caa1a111ad880b9dee62c1c94e32e8275c196fbf 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -64,9 +64,12 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -104,9 +107,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 3}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/transpose0, /*rhs=*/transpose1)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1, 3}), + /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -169,9 +175,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -376,5 +385,69 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); } +// Test that a transpose of every dimension in the activations gets folded into +// convolution. +TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_x = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr conv_shape = ShapeInference::InferConvolveShape( + transpose_x->shape(), y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. + std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0c848566478a25d4862cb0698e029dacd71f7a6a..657a8fe09ae9df906d695f7f49df72500d611792 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -273,6 +273,16 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { + // A kSlice instruction aliases its operand if the backend lowers it to an + // in-place implementation. + if (slice->IsInPlaceSlice()) { + CreateCopiedPointsToSet(slice, slice->operand(0)); + return Status::OK(); + } + return DefaultAction(slice); +} + Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to its output. PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); @@ -427,10 +437,15 @@ bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: instruction %s does not define a " - "buffer at that index", - buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + // kSlice ops that are lowered to an in-place version are expected to not + // define their output buffer. + if (buffer.instruction()->opcode() != HloOpcode::kSlice || + !buffer.instruction()->IsInPlaceSlice()) { + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: instruction %s does not define a " + "buffer at that index", + buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + } } if (buffer.id() < 0 || diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 8928de107eed8c40bbe2130e26fe83ca3802d2f6..c3743b150168ebcf1051050dc511e50c43108c4f 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -199,12 +199,10 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { StatusOr GetBufferDefinedAt( const HloInstruction* instruction, const ShapeIndex& index) const; - // Return a vector containing all BufferAliases of the given logical buffer - // This trivially includes the BufferAlias with same instruction and index as - // the logical buffer itself, so the returned vector is never empty. The - // buffer alias set is the inverse of the points-to set. That is, - // LogicalBuffer B is in the points-to set of instruction I at index N iff - // instruction I, index N is a BufferAlias of B. + // Return a (possibly empty) vector containing all BufferAliases of the given + // logical buffer The buffer alias set is the inverse of the points-to set. + // That is, LogicalBuffer B is in the points-to set of instruction I at index + // N iff instruction I, index N is a BufferAlias of B. using BufferAliasVector = tensorflow::gtl::InlinedVector; const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const; @@ -250,6 +248,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a530bb0b20582b303f4af969514748b46fd5064 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +/*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple, + int64 elements) { + CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + + HloComputation* computation = input_tuple->parent(); + const Shape& input_shape = input_tuple->shape(); + + std::vector tuple_elements; + tuple_elements.reserve(elements); + for (int i = 0; i < elements; i++) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + input_shape.tuple_shapes(i), input_tuple, i))); + } + + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + +/*static*/ HloInstruction* TupleUtil::AppendSuffix( + HloInstruction* input_tuple, + tensorflow::gtl::ArraySlice trailing_values) { + CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + + HloComputation* computation = input_tuple->parent(); + const Shape& input_shape = input_tuple->shape(); + std::vector tuple_elements; + tuple_elements.reserve(input_shape.tuple_shapes_size()); + for (int i = 0; i < input_shape.tuple_shapes_size(); i++) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + input_shape.tuple_shapes(i), input_tuple, i))); + } + tuple_elements.insert(tuple_elements.end(), trailing_values.begin(), + trailing_values.end()); + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e5ff9aaa8357fe8e4777d6dee37bbec72e144c06 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +class TupleUtil { + public: + // Generates HLO instructions to get a prefix tuple from `input_tuple` (which + // must be of tuple shape) of length `elements`. Returns the root of the + // graph of instructions generated. + // + // The instructions are generated into the computation containing + // `input_tuple`. + static HloInstruction* ExtractPrefix(HloInstruction* input_tuple, + int64 elements); + + // Generates HLO instructions to create a tuple that consists of the values in + // `trailing_values` appended to `input_tuple` (which must be of tuple shape). + // Returns the root of the graph of instructions generated. + // + // The instructions are generated into the computation containing + // `input_tuple`. + static HloInstruction* AppendSuffix( + HloInstruction* input_tuple, + tensorflow::gtl::ArraySlice trailing_values); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..754fd8ef169231827eeb5bfd72aeb596644ca767 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_util.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +StatusOr> GetParsedModule( + HloComputation** entry_computation, HloInstruction** param0, + HloInstruction** param1) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[32,32]{1,0},f32[32,32]{1,0},f32[32,32]{1,0}) parameter(0) + ROOT p1 = f32[32,32]{1,0} parameter(1) +} +)"; + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + tools::Parse(hlo_string)); + + *entry_computation = module->entry_computation(); + *param0 = (*entry_computation)->parameter_instruction(0); + *param1 = (*entry_computation)->parameter_instruction(1); + + return std::move(module); +} + +TEST(TupleUtilTest, ExtractPrefix) { + HloInstruction *param0, *param1; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1)); + + HloInstruction* prefix = TupleUtil::ExtractPrefix(param0, 2); + + EXPECT_THAT(prefix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1))); +} + +TEST(TupleUtilTest, AppendSuffix) { + HloInstruction *param0, *param1; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1)); + + HloInstruction* with_suffix = + TupleUtil::AppendSuffix(param0, {param1, param1}); + + EXPECT_THAT(with_suffix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1), + op::GetTupleElement(op::Parameter(0), 2), + op::Parameter(1), op::Parameter(1))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 4e90491b55a5688e37cbabae0843f584578add55..7882b70ab7765ad528b68f97c115e3ae5f19e48a 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -88,8 +88,6 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kAtan2; case BINOP_COMPLEX: return HloOpcode::kComplex; - case BINOP_DOT: - return HloOpcode::kDot; case BINOP_MUL: return HloOpcode::kMultiply; case BINOP_ADD: @@ -371,14 +369,6 @@ StatusOr UserComputation::AddRngInstruction( // Check the number of parameters per RNG distribution. switch (rng_request.distribution()) { - case RandomDistribution::RNG_BERNOULLI: - if (rng_request.parameter_size() != 1) { - return InvalidArgument( - "RNG distribution (%s) expects 1 parameters, but got %d", - RandomDistribution_Name(rng_request.distribution()).c_str(), - rng_request.parameter_size()); - } - break; case RandomDistribution::RNG_NORMAL: case RandomDistribution::RNG_UNIFORM: if (rng_request.parameter_size() != 2) { @@ -765,6 +755,54 @@ StatusOr UserComputation::AddWhileInstruction( return handle; } +StatusOr UserComputation::AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* pred, + LookUpRequest(conditional_request.predicate())); + TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, + LookUpRequest(conditional_request.true_operand())); + TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, + LookUpRequest(conditional_request.false_operand())); + + VersionedComputationHandle::Version true_computation_version = + true_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr true_computation_shape, + true_computation.ComputeProgramShape(true_computation_version)); + + VersionedComputationHandle::Version false_computation_version = + false_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr false_computation_shape, + false_computation.ComputeProgramShape(false_computation_version)); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferConditionalShape( + pred->output_shape(), true_operand->output_shape(), + false_operand->output_shape(), + *true_computation_shape, *false_computation_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(true_computation_version); + request.add_embedded_computation_versions(false_computation_version); + *request.mutable_request()->mutable_conditional_request() = + conditional_request; + + VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << conditional_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddBroadcastInstruction( const BroadcastRequest& broadcast_request) { tensorflow::mutex_lock lock(mutex_); @@ -1075,6 +1113,31 @@ StatusOr UserComputation::AddConvolveInstruction( return handle; } +StatusOr UserComputation::AddFftInstruction( + const FftRequest& fft_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(fft_request.operand())); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferFftShape( + operand->output_shape(), fft_request.fft_type(), + AsInt64Slice(fft_request.fft_length()))); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_fft_request() = fft_request; + + VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << fft_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddCrossReplicaSumInstruction( const CrossReplicaSumRequest& cross_replica_sum_request) { tensorflow::mutex_lock lock(mutex_); @@ -1082,7 +1145,7 @@ StatusOr UserComputation::AddCrossReplicaSumInstruction( TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(cross_replica_sum_request.operand())); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - operand->output_shape())); + {&operand->output_shape()})); ComputationDataHandle handle = CreateComputationDataHandle(); @@ -1192,6 +1255,14 @@ StatusOr UserComputation::AddCustomCallInstruction( TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); } + if (tensorflow::StringPiece(custom_call_request.call_target_name()) + .starts_with("$")) { + return InvalidArgument( + "Invalid custom_call_target \"%s\": Call targets that start with '$' " + "are reserved for internal use.", + custom_call_request.call_target_name().c_str()); + } + const ComputationDataHandle handle = CreateComputationDataHandle(); OperationRequest& request = @@ -1207,6 +1278,33 @@ StatusOr UserComputation::AddCustomCallInstruction( return handle; } +StatusOr UserComputation::AddDotInstruction( + const DotRequest& dot_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, + LookUpRequest(dot_request.lhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, + LookUpRequest(dot_request.rhs())); + + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( + lhs->output_shape(), rhs->output_shape(), + dot_request.dimension_numbers())); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_dot_request() = dot_request; + + VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << dot_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddUnaryInstruction( const UnaryOpRequest& unary_request) { tensorflow::mutex_lock lock(mutex_); @@ -1433,7 +1531,7 @@ StatusOr LookUpRequest( return &session_computation.requests().at(handle_value); } -// Returns the OperationRequestion corresponding to the root (result) of the +// Returns the OperationRequest corresponding to the root (result) of the // session computation. StatusOr GetRoot( VersionedComputationHandle::Version version, @@ -1479,8 +1577,8 @@ UserComputation::ComputeProgramShape( request.request().parameter_request(); int64 param_no = parameter_request.parameter(); // Parameters may be out of order so expand ProgramShape parameters - // until - // it is at least large enough to hold the current parameter number. + // until it is at least large enough to hold the current parameter + // number. while (program_shape->parameters_size() <= param_no) { program_shape->add_parameters(); program_shape->add_parameter_names(); @@ -1594,6 +1692,13 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kFftRequest: { + const FftRequest& fft_request = request.request().fft_request(); + PureFunctionalVisitor(session_computation, fft_request.operand(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kCrossReplicaSumRequest: { // TODO(b/33009255): Implmement constant folding for cross replica sum. *is_functional = false; @@ -1629,6 +1734,15 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + PureFunctionalVisitor(session_computation, dot_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, dot_request.rhs(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kSendRequest: { *is_functional = false; break; @@ -1757,6 +1871,23 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + PureFunctionalVisitor(session_computation, + conditional_request.predicate(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + conditional_request.true_operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + conditional_request.false_operand(), num_parameters, + visited, is_functional); + // TODO(b/32495713): We aren't checking the true and false computations + // themselves. + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -1985,6 +2116,21 @@ UserComputation::GetEmbeddedComputations( break; } + case OpRequest::kConditionalRequest: { + CHECK_EQ(2, request.embedded_computation_versions_size()); + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + const VersionedComputationHandle true_computation_versioned_handle = { + conditional_request.true_computation(), + request.embedded_computation_versions(0)}; + computations.push_back(true_computation_versioned_handle); + const VersionedComputationHandle false_computation_versioned_handle = + {conditional_request.false_computation(), + request.embedded_computation_versions(1)}; + computations.push_back(false_computation_versioned_handle); + break; + } + default: // No embedded computation. break; @@ -2000,6 +2146,24 @@ UserComputation::GetEmbeddedComputations( return computations; } +StatusOr +UserComputation::LookUpRequestForErrorReporting( + const ComputationDataHandle& handle) const { + tensorflow::mutex_lock lock(mutex_); + return LookUpRequest(handle); +} + +tensorflow::gtl::optional UserComputation::ParameterMetadata( + int parameter_number) const { + tensorflow::mutex_lock lock(mutex_); + auto it = parameters_.find(parameter_number); + if (it == parameters_.end()) { + return tensorflow::gtl::nullopt; + } + OperationRequest* op = it->second; + return &op->request().metadata(); +} + Status UserComputation::RemapEmbeddedComputations( const std::map& old_to_new) { auto update = [&old_to_new](ComputationHandle* to_update) -> Status { @@ -2071,6 +2235,16 @@ Status UserComputation::RemapEmbeddedComputations( TF_RETURN_IF_ERROR(update(while_request->mutable_body())); break; } + case OpRequest::kConditionalRequest: { + TF_RET_CHECK(2 == request.embedded_computation_versions_size()); + ConditionalRequest* conditional_request = + request.mutable_request()->mutable_conditional_request(); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_true_computation())); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_false_computation())); + break; + } default: // No embedded computation. TF_RET_CHECK(0 == request.embedded_computation_versions_size()); @@ -2274,6 +2448,12 @@ static void ForEachOperand( break; } + case OpRequest::kFftRequest: { + const FftRequest& fft_request = request.request().fft_request(); + apply(fft_request.operand()); + break; + } + case OpRequest::kBatchNormTrainingRequest: { const BatchNormTrainingRequest& batch_norm_training_request = request.request().batch_norm_training_request(); @@ -2417,6 +2597,15 @@ static void ForEachOperand( break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + apply(conditional_request.predicate()); + apply(conditional_request.true_operand()); + apply(conditional_request.false_operand()); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -2453,6 +2642,13 @@ static void ForEachOperand( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + apply(dot_request.rhs()); + apply(dot_request.lhs()); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); @@ -2653,7 +2849,8 @@ void ComputationLowerer::Visit( const ConstantRequest& constant_request = request.request().constant_request(); hlo_instruction = add_instruction(HloInstruction::CreateConstant( - Literal(constant_request.literal()).CloneToUnique())); + Literal::CreateFromProto(constant_request.literal()) + .ConsumeValueOrDie())); break; } @@ -2732,13 +2929,31 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kFftRequest: { + const FftRequest& fft_request = request.request().fft_request(); + HloInstruction* operand = lookup_instruction(fft_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateFft( + request.output_shape(), operand, fft_request.fft_type(), + AsInt64Slice(fft_request.fft_length()))); + break; + } + + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + HloInstruction* lhs = lookup_instruction(dot_request.lhs()); + HloInstruction* rhs = lookup_instruction(dot_request.rhs()); + hlo_instruction = add_instruction(HloInstruction::CreateDot( + request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); + break; + } + case OpRequest::kCrossReplicaSumRequest: { const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); HloInstruction* operand = lookup_instruction(cross_replica_sum_request.operand()); hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), operand)); + request.output_shape(), {operand})); break; } @@ -3021,6 +3236,30 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + CHECK_EQ(2, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version true_computation_version = + request.embedded_computation_versions(0); + HloComputation* true_computation = ResolveComputation( + conditional_request.true_computation(), true_computation_version); + VersionedComputationHandle::Version false_computation_version = + request.embedded_computation_versions(1); + HloComputation* false_computation = ResolveComputation( + conditional_request.false_computation(), false_computation_version); + HloInstruction* predicate = + lookup_instruction(conditional_request.predicate()); + HloInstruction* true_operand = + lookup_instruction(conditional_request.true_operand()); + HloInstruction* false_operand = + lookup_instruction(conditional_request.false_operand()); + hlo_instruction = add_instruction(HloInstruction::CreateConditional( + request.output_shape(), predicate, true_operand, true_computation, + false_operand, false_computation)); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -3151,8 +3390,7 @@ void ComputationLowerer::Visit( lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - binary_op_request.binop() != BINOP_DOT) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. lhs = diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 317c631dca2e1ebe6f3c8fbaf1a3e94106034f79..4f92e58877a1d06728fdd250744ca2ce7b57d9ad 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -133,6 +133,10 @@ class UserComputation { StatusOr AddConvolveInstruction( const ConvolveRequest& convolve_request); + // Enqueues an FFT instruction onto this user computation. + StatusOr AddFftInstruction( + const FftRequest& fft_request); + // Enqueues a cross replica sum instruction onto this user computation. StatusOr AddCrossReplicaSumInstruction( const CrossReplicaSumRequest& cross_replica_sum_request); @@ -153,6 +157,10 @@ class UserComputation { StatusOr AddCustomCallInstruction( const CustomCallRequest& custom_call_request); + // Enqueues a dot instruction onto this user computation. + StatusOr AddDotInstruction( + const DotRequest& dot_request); + // Enqueues a broadcast instruction onto this user computation. StatusOr AddBroadcastInstruction( const BroadcastRequest& broadcast_request); @@ -216,6 +224,12 @@ class UserComputation { const UserComputation& condition_computation, const UserComputation& body_computation); + // Enqueues a conditional instruction on this user computation. + StatusOr AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation); + // Enqueues a Send instruction onto this user computation. Status AddSendInstruction(const SendRequest& send_request); @@ -307,6 +321,23 @@ class UserComputation { SessionComputation CloneSessionComputation( VersionedComputationHandle::Version version) const; + // Warning: typically we don't want to look up computation data handles until + // the computation is finished being built, for consistency purposes. We + // expose this routine for error reporting purposes so that we can provide + // more meaningful error messages from the XLA service layer. + // + // Returns the operation request that the handle comes from. + StatusOr LookUpRequestForErrorReporting( + const ComputationDataHandle& handle) const; + + // Retrieves the parameter metadata for the given parameter number. + // + // If the parameter number is invalid for this computation, nullopt is + // returned. When the return value has_value(), nullptr will never be + // the held value. + tensorflow::gtl::optional ParameterMetadata( + int parameter_number) const; + private: // Warning: dangerous mutating operation that doesn't respect versioning. // This is only used at initialization time when constructing from a diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 5afaf226ae0cce7e9afc966c6b4adf838aeebc91..ca02115863e6906ef709ba63259024877e0dcef4 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -65,6 +65,7 @@ TEST_F(UserComputationTest, SimpleComputation) { OutfeedRequest outfeed_request; *outfeed_request.mutable_operand() = constant_handle; + *outfeed_request.mutable_shape() = kVectorShape; outfeed_request.set_outfeed_config("abc"); TF_ASSERT_OK(computation.AddOutfeedInstruction(outfeed_request)); @@ -334,50 +335,5 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { operands[1]->opcode() == HloOpcode::kBroadcast); } -TEST_F(UserComputationTest, SkipDotInEliminatingImplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // %a = Param({1, 3}); - // %b = Param({3, 1}); - // %dot = Dot(%a, %b); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {3, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest dot; - dot.set_binop(BINOP_DOT); - *dot.mutable_lhs() = a_handle; - *dot.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(dot).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - EXPECT_EQ(3, hlo_computation->instruction_count()); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5f9b01f011ce04f1114c74391a967c62f015221 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -0,0 +1,296 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" +#include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +using tensorflow::gtl::FlatMap; +using tensorflow::gtl::FlatSet; +using tensorflow::gtl::InlinedVector; + +// Copies `to_hoist` to the computation containing `while_instr`, hoisting its +// operands as needed. All of its transitive operands are expected to be either +// in `hoisted_instructions` or `unhoisted_invariant_instructions`. This +// function hoists the operands in `unhoisted_invariant_instructions` and moves +// them into `hoisted_instructions`. +static void CreateLoopInvariantCopy( + FlatMap* hoisted_instructions, + FlatSet* unhoisted_invariant_instructions, + HloInstruction* while_instr, HloInstruction* to_hoist) { + HloComputation* parent_of_while = while_instr->parent(); + HloComputation* while_body = while_instr->while_body(); + + struct DFSFrame { + HloInstruction* instruction; + int64 operand_index; + }; + + InlinedVector dfs_stack; + dfs_stack.push_back({to_hoist, 0}); + + HloInstruction* while_body_param = while_body->parameter_instruction(0); + HloInstruction* while_operand = while_instr->mutable_operand(0); + + do { + DFSFrame* frame = &dfs_stack.back(); + if (frame->operand_index == frame->instruction->operand_count()) { + HloInstruction* old_instruction = frame->instruction; + + // All of the operands for old_instruction have been cloned, so it is + // time to clone old_instruction itself. + + auto get_new_operand = [&](HloInstruction* old_operand) { + return old_operand == while_body_param + ? while_operand + : FindOrDie(*hoisted_instructions, old_operand); + }; + + InlinedVector new_operands; + c_transform(old_instruction->operands(), std::back_inserter(new_operands), + get_new_operand); + + HloInstruction* new_instruction = + parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands( + old_instruction->shape(), new_operands)); + + InsertOrDie(hoisted_instructions, old_instruction, new_instruction); + + // Approximately half of the instructions that would normally be present + // in unhoisted_invariant_instructions are constants. We save a bit of + // compile time by not putting these in the hashtable. + CHECK_EQ(unhoisted_invariant_instructions->erase(old_instruction), + to_hoist != old_instruction && + old_instruction->opcode() != HloOpcode::kConstant); + dfs_stack.pop_back(); + continue; + } + + HloInstruction* next_operand = + frame->instruction->mutable_operand(frame->operand_index++); + if (hoisted_instructions->count(next_operand) || + next_operand == while_body_param) { + continue; + } + + dfs_stack.push_back({next_operand, 0}); + } while (!dfs_stack.empty()); +} + +// Returns true if `instruction` is worth hoisting only if it lets us hoist some +// instruction using it. The rationale is that hoisting these instructions will +// prevent simplification and fusion in the while body. +static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { + switch (instruction.opcode()) { + default: + return false; + + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kConstant: + case HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kTuple: + return true; + + case HloOpcode::kTranspose: + return ShapeUtil::TransposeIsBitcast( + /*input_shape=*/instruction.operand(0)->shape(), + /*output_shape=*/instruction.shape(), instruction.dimensions()); + + case HloOpcode::kReshape: + return ShapeUtil::ReshapeIsBitcast( + /*input_shape=*/instruction.operand(0)->shape(), + /*output_shape=*/instruction.shape()); + } +} + +// Populates `gte_set` with the GetTupleElement instructions in `while_body` +// that access elements in the parameter tuple that don't change across +// iterations. Assumes `while_body` is the body computation of the while loop +// in question. +static void GatherInvariantGTEs(HloComputation* while_body, + FlatSet* gte_set) { + const HloInstruction::InstructionVector root_operands = + while_body->root_instruction()->operands(); + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && + instr->operand(0) == while_body->parameter_instruction(0) && + ShapeUtil::IsArray(instr->shape())) { + InsertOrDie(gte_set, instr); + } + } +} + +static StatusOr TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr) { + auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); + + if (!ShapeUtil::IsTuple(while_instr->shape())) { + // This restriction leaves one interesting pattern on the table: + // + // while_body(f32[1024, 1024] %param) { + // %value = expensive_op(%param) + // outfeed(%value) + // ROOT = %param + // } + // + // If we see that pattern in the while, instead of generalizing this + // algorithm to work with non-tuples, we should instead add a pass that + // canonicalizes while loops like the above to use a tuple state. + return false; + } + + string while_instr_name = while_instr->ToString(print_no_metadata); + VLOG(2) << "Trying to hoist from " << while_instr_name; + + HloComputation* while_body = while_instr->while_body(); + + // Maps instructions in the while body to instructions hoisted outside the + // while that compute the same value. + FlatMap hoisted_instructions; + + // Contains instructions that can be legally hoisted, but were deemed to be + // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we + // hoist an instruction in this set, we move it from + // unhoisted_invariant_instructions to hoisted_instructions. + FlatSet unhoisted_invariant_instructions; + + // Invariant GTE's axiomatically satisfy the constraints for + // unhoisted_invariant_instructions -- they can be legally hoisted, but there + // is no benefit to hoisting them unless something that uses it is also + // hoisted. + GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions); + + if (unhoisted_invariant_instructions.empty()) { + // There are no obviously loop invariant elements in the state being + // threaded through the while loop so give up. In theory this precondition + // is too strong -- we could have code that e.g. permutes the elements in + // the while state but uses a select to pick the same value on every + // iteration. + return false; + } + + // instructions_to_replace[i] is hoisted into a loop invariant instruction + // replacement_instructions[i]. + std::vector instructions_to_replace; + std::vector replacement_instructions; + + for (auto* instruction : while_body->MakeInstructionPostOrder()) { + if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kParameter || + !instruction->control_predecessors().empty() || + !instruction->control_successors().empty()) { + continue; + } + + auto is_invariant = [&](HloInstruction* op) { + return hoisted_instructions.find(op) != hoisted_instructions.end() || + unhoisted_invariant_instructions.count(op) || + op->opcode() == HloOpcode::kConstant; + }; + + if (!c_all_of(instruction->operands(), is_invariant)) { + continue; + } + + if (NotWorthHoistingIndividually(*instruction)) { + VLOG(2) << "Adding " << instruction->ToString(print_no_metadata) + << " to unhoisted invariant set."; + // Approximately half of the instructions that reach this point are + // constants. We save a bit of compile time by not putting these in the + // hashtable. + if (instruction->opcode() != HloOpcode::kConstant) { + InsertOrDie(&unhoisted_invariant_instructions, instruction); + } + continue; + } + + VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata); + + CreateLoopInvariantCopy(&hoisted_instructions, + &unhoisted_invariant_instructions, while_instr, + instruction); + + instructions_to_replace.push_back(instruction); + replacement_instructions.push_back( + FindOrDie(hoisted_instructions, instruction)); + } + + if (instructions_to_replace.empty()) { + return false; + } + + TF_ASSIGN_OR_RETURN( + WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions)); + + HloComputation* new_while_body = + live_in_instructions_result.new_while_instr->while_body(); + + for (int i = 0; i < instructions_to_replace.size(); i++) { + HloInstruction* instruction_to_replace_in_new_while = + FindOrDie(live_in_instructions_result.while_body_instruction_map, + instructions_to_replace[i]); + TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction( + instruction_to_replace_in_new_while, + live_in_instructions_result.while_body_live_in_values[i])); + } + + VLOG(1) << "Hoisted " << instructions_to_replace.size() + << " instructions from " << while_instr_name; + + return true; +} + +StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { + bool changed = false; + std::vector while_instrs; + for (auto* comp : module->computations()) { + c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); + } + + for (HloInstruction* while_instr : while_instrs) { + // Right now we only hoist computations from the while body, but + // TryHoistingInvariantInstructionsFromWhileBody can be generalized to + // optimize the condition computation too, if needed. + // + // The transform we do here is a pessmization for while loops that execute + // zero times*, but at this time we expect those to be rare. If this + // becomes a problem we can consider using the conditional HLO to avoid + // doing extra work for while loops with zero trip count. + // + // * We delete while loops that have a zero trip count, so this would have + // to be a while loop with a somewhat opaque condition expression. + + TF_ASSIGN_OR_RETURN( + bool result, + TryHoistingInvariantInstructionsFromWhileBody(while_instr)); + changed |= result; + } + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h new file mode 100644 index 0000000000000000000000000000000000000000..8c4b765b0003c48cfacb9d28e7c8259ac0927d66 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass that rewrites while loops to hoist loop invariant instructions in +// the while body into the computation that contains the while instruction. + +class WhileLoopInvariantCodeMotion : public HloPassInterface { + public: + ~WhileLoopInvariantCodeMotion() override = default; + + tensorflow::StringPiece name() const override { + return "while-loop-invariant-code-motion"; + } + StatusOr Run(HloModule* module) override; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..799340fda905fb7d40b19b4cb79bb0fcb5629fd3 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -0,0 +1,442 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { + public: + // Makes a computation which has one parameter, of the given shape, and always + // returns PRED[]{true}. This is useful as a dummy loop condition. + HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, + HloModule* module); +}; + +static void FindOnlyWhileInstruction(HloComputation* computation, + HloInstruction** while_instruction) { + *while_instruction = nullptr; + for (auto* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kWhile) { + ASSERT_EQ(*while_instruction, nullptr); + *while_instruction = instr; + } + } + + ASSERT_NE(*while_instruction, nullptr); +} + +HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation( + const Shape& param_shape, HloModule* module) { + HloComputation::Builder builder(TestName() + ".always_true"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "param")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + return module->AddEmbeddedComputation(builder.Build()); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, add_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + HloComputation* entry_computation = + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloInstruction* transformed_while; + FindOnlyWhileInstruction(entry_computation, &transformed_while); + + EXPECT_THAT(entry_computation->instructions(), Contains(op::Add())); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::Add()))); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* gte_2_loop_variant = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 2)); + + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + HloInstruction* mul_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kMultiply, add_result, gte_1)); + HloInstruction* negate_result = + builder.AddInstruction(HloInstruction::CreateUnary( + scalar_s32, HloOpcode::kNegate, mul_result)); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(4))); + HloInstruction* sub_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kSubtract, negate_result, constant)); + HloInstruction* divide_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, divide_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + HloComputation* entry_computation = + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloInstruction* transformed_while; + FindOnlyWhileInstruction(entry_computation, &transformed_while); + + EXPECT_THAT(entry_computation->instructions(), + AllOf(Contains(op::Add()), Contains(op::Multiply()), + Contains(op::Negate()), Contains(op::Subtract()), + Contains(op::Constant()), + + // The division had a loop varying operand so that better + // not be hoisted. + Not(Contains(op::Divide())))); + + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(), + op::Subtract(), op::Constant())))); + + EXPECT_THAT(transformed_while->while_body()->instructions(), + Contains(op::Divide())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, + DontHoistTriviallyLoopVaryingComputation) { + // Basic negative test: the add expression is not loop invariant. + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, + DontHoistLoopVaryingComputationWithAlternatingTuples) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_1, gte_0, add_result})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_s32, gte_0, "")); + builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), + Contains(op::Outfeed())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) { + // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the + // bitcast either. + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* bitcast_inst = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + builder.AddInstruction( + HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, "")); + builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); + + EXPECT_THAT(while_inst->while_body()->instructions(), + Contains(op::Outfeed())); + EXPECT_THAT(while_inst->while_body()->instructions(), + Contains(op::Bitcast())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) { + // The bitcast's user can be hoisted, so hoist the bitcast too. + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto scalar_f32 = ShapeUtil::MakeShape(F32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_f32, param, 1)); + HloInstruction* bitcast_inst = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0)); + HloInstruction* add_inst = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, add_inst})); + + return module().AddEmbeddedComputation(builder.Build()); + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + + HloComputation* entry_computation = + module().AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloInstruction* transformed_while; + FindOnlyWhileInstruction(entry_computation, &transformed_while); + + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::Add()))); + EXPECT_THAT(transformed_while->while_body()->instructions(), + Each(Not(op::Bitcast()))); + EXPECT_THAT(entry_computation->instructions(), Contains(op::Add())); + EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast())); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32}); + + HloComputation* while_body; + { + HloComputation::Builder builder(TestName() + ".while_body"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloInstruction* gte_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)); + HloInstruction* gte_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + HloInstruction* add_result = + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, gte_0, gte_1)); + TF_ASSERT_OK(param->AddControlDependencyTo(add_result)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte_0, gte_1, add_result})); + + while_body = module().AddEmbeddedComputation(builder.Build()); + } + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".passthrough"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + + result->AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + return result; + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index b38ee907d70e29093c5cef718e1432663015728b..87a7f86f4ec9844de3e350d7774093dd6248dd83 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -236,7 +236,7 @@ static optional GetLoopTripCount(HloInstruction* while_op) { VLOG(2) << "Couldn't evaluate while cond: " << result.status(); return nullopt; } - return result.ValueOrDie()->GetArraySlice() == + return result.ValueOrDie()->data() == tensorflow::gtl::ArraySlice{true}; }; @@ -289,7 +289,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { // Don't try this transformation if the while loop isn't removable, since if // it succeeds ultimately we're going to have to replace the old while loop // with a new one. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Can't remove dead parameters from non-removable while op."; return false; } @@ -306,6 +306,13 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return false; } + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) instruction."; + return false; + } + + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + // Bail if param0 of while_cond or while_body has users which aren't of type // get-tuple-element. for (const HloInstruction* instr : {while_body->parameter_instruction(0), @@ -313,9 +320,10 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { for (const HloInstruction* user : instr->users()) { if (user->opcode() != HloOpcode::kGetTupleElement) { VLOG(2) << "Cowardly refusing to analyze while loop with " - << instr->ToStringNoMetadata() - << " used by non-GTE instruction " << user->ToStringNoMetadata() - << " in computation " << instr->parent()->name(); + << instr->ToString(print_no_metadata) + << " used by non-GTE instruction " + << user->ToString(print_no_metadata) << " in computation " + << instr->parent()->name(); return false; } } @@ -351,7 +359,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.insert(user->tuple_index()); if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) << " uses all of its inputs; no simplification possible."; return false; } @@ -375,7 +383,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.insert(i); if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) << " uses all of its inputs; no simplification possible."; return false; } @@ -387,7 +395,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { CHECK_LT(used_tuple_indices.size(), tuple_size); VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() - << " elements from tuple of " << while_op->ToStringNoMetadata(); + << " elements from tuple of " + << while_op->ToString(print_no_metadata); // Build up maps from the old/new to the new/old tuple indices. std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), @@ -431,7 +440,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { continue; } CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement) - << user->ToStringNoMetadata(); + << user->ToString(print_no_metadata); int64 old_idx = user->tuple_index(); auto new_idx_iter = old_to_new_tuple_idx.find(old_idx); @@ -446,14 +455,14 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { CHECK(user->user_count() == 0 || user->user_count() == 1 && user->users().front() == while_body_root) - << "Instruction " << user->ToStringNoMetadata() + << "Instruction " << user->ToString(print_no_metadata) << " should be unused (except by root of while body), but has " "users: {" << tensorflow::str_util::Join( user->users(), ", ", - [](string* out, const HloInstruction* instr) { + [&](string* out, const HloInstruction* instr) { tensorflow::strings::StrAppend( - out, instr->ToStringNoMetadata()); + out, instr->ToString(print_no_metadata)); }) << "}"; @@ -558,7 +567,7 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // the loop aren't removed, just cloned and added back to the loop. // Nevertheless our infrastructure sees loop simplification as removal of // these nodes and currently doesn't allow it. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Not attempting to remove while loop it is not removable: " << while_op->ToShortString(); return false; @@ -586,7 +595,9 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { auto call_op = computation->AddInstruction(HloInstruction::CreateCall( while_op->shape(), while_op->operands(), while_op->while_body())); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); - TF_RETURN_IF_ERROR(CallInliner::Inline(call_op)); + TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_op)); + (void)inlined_instructions_map; return true; } return false; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index d99b31dc0037968bc88d5f22d53309a6a4546963..c5183f8d3aee99696ed4114c3f7e451888222137 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -418,5 +418,32 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); } +TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".passthrough"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + + result->AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + return result; + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopSimplifier{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..e20b25e4a08a946f6b58575a4d4e557744f8035c --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/tuple_util.h" + +namespace xla { + +static StatusOr WidenWhileCondition( + HloComputation* narrow_condition, const Shape& wide_shape) { + const Shape& narrow_shape = + narrow_condition->parameter_instruction(0)->shape(); + + HloComputation* wide_while_cond = [&]() { + HloComputation::Builder builder( + tensorflow::strings::StrCat("wide.", narrow_condition->name())); + builder.AddInstruction( + HloInstruction::CreateParameter(0, wide_shape, "wide_param")); + + // This is needed so that the root instruction is shaped as a PRED[] -- we + // need to get this right to begin with since we can't mutate the type of + // the root instruction later. We later change the root instruction to + // something more appropriate. + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + return narrow_condition->parent()->AddEmbeddedComputation(builder.Build()); + }(); + + HloInstruction* truncated_parameter = + TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0), + narrow_shape.tuple_shapes_size()); + HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction( + HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}), + {truncated_parameter}, narrow_condition)); + + wide_while_cond->set_root_instruction(call_narrow_cond); + + TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status()); + return wide_while_cond; +} + +static StatusOr> +WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { + const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape(); + + HloComputation* wide_while_body = [&]() { + HloComputation::Builder builder( + tensorflow::strings::StrCat("wide.", narrow_body->name())); + builder.AddInstruction( + HloInstruction::CreateParameter(0, wide_shape, "wide_param")); + return narrow_body->parent()->AddEmbeddedComputation(builder.Build()); + }(); + + HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0); + HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix( + wide_parameter, narrow_shape.tuple_shapes_size()); + HloInstruction* call_narrow_body = + wide_while_body->AddInstruction(HloInstruction::CreateCall( + narrow_shape, {truncated_parameter}, narrow_body)); + + std::vector live_through_values; + for (int i = narrow_shape.tuple_shapes_size(); + i < wide_shape.tuple_shapes_size(); i++) { + live_through_values.push_back( + wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + wide_shape.tuple_shapes(i), wide_parameter, i))); + } + + wide_while_body->set_root_instruction( + TupleUtil::AppendSuffix(call_narrow_body, live_through_values)); + + TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_narrow_body)); + return {{wide_while_body, std::move(inlined_instructions_map)}}; +} + +/*static*/ StatusOr +WhileUtil::MakeInstructionsLiveIn( + HloInstruction* while_instr, + tensorflow::gtl::ArraySlice instructions) { + CHECK(ShapeUtil::IsTuple(while_instr->shape())); + + int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); + Shape new_while_shape = while_instr->shape(); + for (auto* instruction : instructions) { + *new_while_shape.add_tuple_shapes() = instruction->shape(); + } + + TF_ASSIGN_OR_RETURN( + HloComputation * new_while_condition, + WidenWhileCondition(while_instr->while_condition(), new_while_shape)); + + HloComputation* new_while_body; + CallInliner::InlinedInstructionMap inlined_instructions_map; + TF_ASSIGN_OR_RETURN( + std::tie(new_while_body, inlined_instructions_map), + WidenWhileBody(while_instr->while_body(), new_while_shape)); + + HloInstruction* new_while_init = + TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions); + HloComputation* containing_computation = while_instr->parent(); + HloInstruction* new_while = containing_computation->AddInstruction( + HloInstruction::CreateWhile(new_while_shape, new_while_condition, + new_while_body, new_while_init)); + TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( + while_instr, TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + + HloInstruction* while_body_param = new_while_body->parameter_instruction(0); + std::vector live_in_instructions; + for (int64 i = elements_in_old_while_shape; + i < new_while_shape.tuple_shapes_size(); i++) { + live_in_instructions.push_back( + new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + instructions[i - elements_in_old_while_shape]->shape(), + while_body_param, i))); + } + + WhileUtil::MakeInstructionsLiveInResult result; + + result.new_while_instr = new_while; + result.while_body_live_in_values = std::move(live_in_instructions); + result.while_body_instruction_map = std::move(inlined_instructions_map); + + return std::move(result); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h new file mode 100644 index 0000000000000000000000000000000000000000..3600b5a80d26e37fdb7d5173c3b8743734306390 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ + +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +class WhileUtil { + public: + // Holds a return value from MakeInstructionsLiveIn. + struct MakeInstructionsLiveInResult { + // The new while operation that has the requested values live in. + HloInstruction* new_while_instr; + + // The i'th element of `while_body_live_in_values` is an instruction in the + // while body that holds the i'th *newly added* live in value at runtime. + std::vector while_body_live_in_values; + + // `while_body_instruction_map` maps instructions in the original while body + // to the corresponding instructions in the body for the newly created while + // operation. + CallInliner::InlinedInstructionMap while_body_instruction_map; + }; + + // Replaces `while_instr` with a new while instruction that is equivalent to + // `while_instr`, except that it has all of the HLO instructions in + // `instructions` as live-in, loop invariant values. These new live in values + // are represented as new elements appended to the parameter of the while + // loop, which must be of tuple shape. GetTupleElement instructions computing + // each new live in value is returned in the `while_body_live_in_values` + // vector. + // + // Precondition: `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. + static StatusOr MakeInstructionsLiveIn( + HloInstruction* while_instr, + tensorflow::gtl::ArraySlice instructions); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf0d0db99bd92b6b364b4e28e56a0902d4065963 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_util.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +StatusOr> GetParsedModule( + HloComputation** entry_computation, HloInstruction** param0, + HloInstruction** param1, HloInstruction** param2) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +while_body { + ROOT p_body = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0) +} + +while_condition { + p_cond = f32[32,32]{1,0} parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + p_entry_0 = f32[32,32]{1,0} parameter(0) + p_entry_1 = s32[32,32]{1,0} parameter(1) + p_entry_2 = s64[32,32]{1,0} parameter(2) + while_init = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p_entry_0, p_entry_0) + ROOT while = (f32[32,32]{1,0}, f32[32,32]{1,0}) while(while_init), condition=while_condition, body=while_body +} +)"; + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + tools::Parse(hlo_string)); + + *entry_computation = module->entry_computation(); + *param0 = (*entry_computation)->parameter_instruction(0); + *param1 = (*entry_computation)->parameter_instruction(1); + *param2 = (*entry_computation)->parameter_instruction(2); + + return std::move(module); +} + +TEST(WhileUtil, MakeZeroInstructionsLiveOp) { + HloInstruction *param0, *param1, *param2; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); + + HloInstruction* while_instr = entry_computation->root_instruction(); + ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, /*instructions=*/{})); + + HloInstruction* new_while_instr = make_live_in_result.new_while_instr; + + EXPECT_THAT( + entry_computation->root_instruction(), + op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), + op::GetTupleElement(::testing::Eq(new_while_instr), 1))); + + auto param_reconstructed = + op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1)); + + EXPECT_THAT(new_while_instr->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(param_reconstructed, 0), + op::GetTupleElement(param_reconstructed, 1))); +} + +TEST(WhileUtilTest, MakeTwoInstructionsLive) { + HloInstruction *param0, *param1, *param2; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1, ¶m2)); + + HloInstruction* while_instr = entry_computation->root_instruction(); + ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{param0, param1})); + + HloInstruction* new_while_instr = make_live_in_result.new_while_instr; + + XLA_VLOG_LINES(3, module->ToString()); + + EXPECT_THAT( + entry_computation->root_instruction(), + op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), + op::GetTupleElement(::testing::Eq(new_while_instr), 1))); + + auto first_half_param_reconstructed = + op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1)); + + EXPECT_THAT(new_while_instr->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(first_half_param_reconstructed, 0), + op::GetTupleElement(first_half_param_reconstructed, 1), + op::GetTupleElement(op::Parameter(0), 2), + op::GetTupleElement(op::Parameter(0), 3))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa40b5cb264803097f52966d6f61f1f41b6b3017 --- /dev/null +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr ZeroSizedHloElimination::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { + if (instruction->HasSideEffect() || + ShapeUtil::IsTuple(instruction->shape())) { + continue; + } + if (comp->IsRemovable(instruction) && + ShapeUtil::HasZeroElements(instruction->shape())) { + TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( + instruction, HloInstruction::CreateConstant( + Literal::CreateFromShape(instruction->shape())))); + changed = true; + } + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h new file mode 100644 index 0000000000000000000000000000000000000000..63afab4206eb072e84745ced3307295c0516da7b --- /dev/null +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +// HLO pass that replaces zero sized Hlos with an zero sized constant literal. +namespace xla { +class ZeroSizedHloElimination : public HloPassInterface { + public: + StatusOr Run(HloModule* module) override; + tensorflow::StringPiece name() const override { + return "zero_sized_hlo_elimination"; + } +}; +} // namespace xla +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f8cdc1e0e73cdaa8675fc945ba3dbe19ce3da7d --- /dev/null +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace { +class ZeroSizedHloEliminationTest : public HloTestBase { + protected: + ZeroSizedHloEliminationTest() + : HloTestBase(), + builder_("zero_sized_computation"), + zero_sized_param_( + builder_.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} + + StatusOr RunZeroSizedElimination() { + HloModule module("zero_sized_elimination_test_module"); + module.AddEntryComputation(builder_.Build()); + return ZeroSizedHloElimination{}.Run(&module); + } + + HloComputation::Builder builder_; + HloInstruction* zero_sized_param_; +}; + +TEST_F(ZeroSizedHloEliminationTest, EliminatedZeroSizedOp) { + builder_.AddInstruction(HloInstruction::CreateUnary( + zero_sized_param_->shape(), HloOpcode::kTanh, zero_sized_param_)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_TRUE(changed); +} + +TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateParameter) { + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_FALSE(changed); +} + +TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateSideEffects) { + builder_.AddInstruction(HloInstruction::CreateSend(zero_sized_param_, 0)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 5bf9842a6ce7be747f58c10f302f85c6f82ac6f9..789eba5780d37e1fd4d80ec881855951c8bba0eb 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -32,13 +32,13 @@ tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { return tensorflow::Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* other_shape) const { - if (!ShapeUtil::Compatible(*other_shape, shape_)) { +tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { + if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(*other_shape).c_str(), + ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } - *other_shape = shape_; + *to_shape = shape_; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 92564660f21bf1b596c4b9ca04c07eaca27ed192..4c83750f3e6f3c735db66d8e0b86ae3f43e5ca11 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -38,18 +38,19 @@ class ShapeLayout { explicit ShapeLayout(const Shape& shape) : shape_(shape) {} // Assigns the layouts in this ShapeLayout to the Layout fields of the given - // shape. 'shape' and the shape of the ShapeLayout object must be compatible. - tensorflow::Status AssignLayoutToShape(Shape* shape) const; + // shape. 'to_shape' and the shape of the ShapeLayout object must be + // compatible. + tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible // with the ShapeLayout's shape, then false is returned. bool MatchesLayoutInShape(const Shape& shape) const; - // Copies the layout from the given shape into this ShapeLayout. 'shape' must - // be compatible with the ShapeLayout's shape, and 'shape' must have a layout - // (LayoutUtil::HasLayout). - tensorflow::Status CopyLayoutFromShape(const Shape& shape); + // Copies the layout from the given shape into this ShapeLayout. 'other_shape' + // must be compatible with the ShapeLayout's shape, and 'other_shape' must + // have a layout (LayoutUtil::HasLayout). + tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index bf8d19015079f2ce0bd450594040ed818f94b66b..d752619bd65751779c24f061e44e206d66b01465 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -238,7 +238,7 @@ class ShapeTree { // (or compatible). // index : the index of the element in the shape. See ShapeUtil::GetSubshape // for definition of index. - // data : The data value at this elemnt. + // data : The data value at this element. template void ForEachElement(const Fn& func) const; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 74fa0b2f2e740310be23661caef3f19e24e4087b..cba73322fa924785fbc73a4e931b5f27227d89b9 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -58,36 +59,47 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { return out; } +std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { + out << shape_index.ToString(); + return out; +} + namespace { // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (ShapeUtil::IsTuple(lhs)) { - return ShapeUtil::IsTuple(rhs) && + if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { + return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { return CompareShapes(l, r, compare_layouts); }); + } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { + return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); } - // Explicitly compare the fields rather than using MessageDifferencer because - // we want empty layouts to be treated identically to missing layouts. + if (compare_layouts) { - if (!ContainersEqual(lhs.layout().minor_to_major(), - rhs.layout().minor_to_major())) { - VLOG(3) << "CompareShapes: lhs layout != rhs layout"; - return false; - } - if (!ContainersEqual(lhs.layout().padded_dimensions(), - rhs.layout().padded_dimensions())) { - VLOG(3) - << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; + if (lhs.layout().format() != rhs.layout().format()) { return false; } - if (lhs.layout().padding_value() != rhs.layout().padding_value()) { - VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value"; - return false; + if (LayoutUtil::IsDenseArray(lhs)) { + if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), + LayoutUtil::MinorToMajor(rhs))) { + VLOG(3) << "CompareShapes: lhs layout != rhs layout"; + return false; + } + if (!ContainersEqual(lhs.layout().padded_dimensions(), + rhs.layout().padded_dimensions())) { + VLOG(3) + << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; + return false; + } + if (lhs.layout().padding_value() != rhs.layout().padding_value()) { + VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value"; + return false; + } } } @@ -141,7 +153,8 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(!ShapeUtil::IsTuple(shape)) << "Tuples do not have a rank"; + CHECK(!ShapeUtil::IsTuple(shape)) + << "Tuples do not have a rank, shape: " << shape; return shape.dimensions_size(); } @@ -182,20 +195,32 @@ StatusOr MakeShapeWithLayoutInternal( .ValueOrDie(); } -/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( +/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); return MakeShapeWithLayout(element_type, dimensions, layout); } -/* static */ Shape ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout( +/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements) { + DCHECK_NE(TUPLE, element_type); + DCHECK_NE(OPAQUE, element_type); + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); + TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); + return shape; +} + +/* static */ Shape +ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( const Shape& shape) { std::vector dims(shape.dimensions_size()); for (int i = 0; i < shape.dimensions_size(); ++i) { dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims); + return MakeShapeWithDescendingLayout(shape.element_type(), dims); } /* static */ void ShapeUtil::PopulateShape( @@ -235,6 +260,7 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { + CHECK(LayoutUtil::IsDenseArray(*shape)); shape->mutable_layout()->add_minor_to_major(Rank(*shape)); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); @@ -329,6 +355,14 @@ StatusOr MakeShapeWithLayoutInternal( return MakeTupleShape(new_elements); } +// Returns the shape of a real or imaginary component. +/* static */ Shape ShapeUtil::ComplexComponentShape( + const Shape& complex_shape) { + CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape); + return ChangeElementType(complex_shape, primitive_util::ComplexComponentType( + complex_shape.element_type())); +} + /* static */ bool ShapeUtil::ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions) { @@ -336,7 +370,7 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(!IsTuple(shape)); + CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -352,7 +386,7 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (shape.element_type() == TUPLE) { + if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -396,10 +430,30 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) { static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); return gen->LowercaseName(s); } + +StatusOr StringToPrimitiveType(const string& name) { + static std::unordered_map* name_to_type = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + auto value = static_cast(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + auto found = name_to_type->find(name); + if (found == name_to_type->end()) { + return InvalidArgument("Invalid element type string: \"%s\".", + name.c_str()); + } + return found->second; +} + } // namespace /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { - if (shape.element_type() == TUPLE) { + if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -470,26 +524,35 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { string element_type_string; string dimensions_string; + string format_string; string layout_string; // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding // amount from our StringPiece type. tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume(&s_consumable, - "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*{([\\d,]*)})?", - &element_type_string, &dimensions_string, &layout_string)) { + if (RE2::Consume( + &s_consumable, + "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?", + &element_type_string, &dimensions_string, &format_string, + &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); + auto string_to_int64 = [&s](const string& input) -> StatusOr { + int64 element; + if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { + return InvalidArgument( + "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", + input.c_str(), s->ToString().c_str()); + } + return element; + }; + auto comma_list_to_int64s = - [&s](const string& input) -> StatusOr> { + [&s, + string_to_int64](const string& input) -> StatusOr> { std::vector results; for (const string& piece : tensorflow::str_util::Split(input, ',')) { - int64 element; - if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) { - return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - piece.c_str(), s->ToString().c_str()); - } + TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); results.push_back(element); } return results; @@ -500,31 +563,32 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { comma_list_to_int64s(dimensions_string)); // Extract the primitive element type. - PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID; - for (PrimitiveType i = - static_cast(PRIMITIVE_TYPE_INVALID + 1); - i < TUPLE; i = static_cast(i + 1)) { - if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) == - element_type_string) { - primitive_type = i; - break; - } - } - if (primitive_type == PRIMITIVE_TYPE_INVALID) { + TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, + StringToPrimitiveType(element_type_string)); + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || + primitive_type == OPAQUE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } Shape result; - if (layout_string.empty()) { + if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. result = ShapeUtil::MakeShape(primitive_type, dimensions); - } else { + } else if (format_string == "sparse") { + TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); + result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, + max_elements); + } else if (format_string.empty() || format_string == "dense") { // Extract the layout minor-to-major and set it. TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( primitive_type, dimensions, min2maj)); + } else { + // This should not be reached. + LOG(FATAL) << "Unhandled condition when parsing shape; format: \"" + << format_string << "\", layout: \"" << layout_string << "\""; } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); return std::move(result); @@ -537,7 +601,12 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ StatusOr ShapeUtil::ParseShapeString( tensorflow::StringPiece s) { - return ParseShapeStringInternal(&s); + TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); + if (!s.empty()) { + return InvalidArgument("Invalid shape string to parse: \"%s\"", + s.ToString().c_str()); + } + return shape; } /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, @@ -622,23 +691,55 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { TF_DCHECK_OK(ValidateShape(shape)); DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { - CHECK_GT(pointer_size, 0); - return pointer_size * shape.tuple_shapes_size(); + return ByteSizeOfTupleIndexTable(shape, pointer_size); } + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; +} + +/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, + int64 pointer_size) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_GT(pointer_size, 0); + return pointer_size * shape.tuple_shapes_size(); +} + +/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; - if (shape.layout().padded_dimensions_size() > 0) { - CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size()); - allocated_element_count = 1; - for (int64 dimension_size : shape.layout().padded_dimensions()) { - allocated_element_count *= dimension_size; - } + + if (LayoutUtil::IsSparseArray(shape)) { + allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { - allocated_element_count = ElementsIn(shape); + CHECK(LayoutUtil::IsDenseArray(shape)); + tensorflow::gtl::ArraySlice padded_dimensions = + LayoutUtil::PaddedDimensions(shape); + if (!padded_dimensions.empty()) { + CHECK_EQ(Rank(shape), padded_dimensions.size()); + allocated_element_count = 1; + for (int64 dimension_size : padded_dimensions) { + allocated_element_count *= dimension_size; + } + } else { + allocated_element_count = ElementsIn(shape); + } } return allocated_element_count * ByteSizeOfPrimitiveType(shape.element_type()); } +/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(LayoutUtil::IsSparseArray(shape)); + return LayoutUtil::MaxSparseElements(shape.layout()) * + ShapeUtil::Rank(shape) * sizeof(int64); +} + /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { if (shape.element_type() == TUPLE) { @@ -694,9 +795,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return LayoutUtil::ValidateLayoutInShape(shape); } -/* static */ Shape ShapeUtil::ChangeElementType(const Shape& shape, +/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, PrimitiveType type) { - Shape new_shape = shape; + Shape new_shape = original; new_shape.set_element_type(type); return new_shape; } @@ -705,7 +806,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)); + CHECK(IsTuple(*return_shape)) + << "Invalid index " << index << " for shape " << shape; return_shape = &return_shape->tuple_shapes(i); } return *return_shape; @@ -863,7 +965,9 @@ Status ForEachMutableSubshapeHelper( new_shape.add_dimensions(dim); } if (shape.has_layout()) { + CHECK(LayoutUtil::IsDenseArray(shape)); Layout* new_layout = new_shape.mutable_layout(); + new_layout->set_format(DENSE); new_layout->clear_minor_to_major(); for (auto index : Permute(permutation, shape.layout().minor_to_major())) { new_layout->add_minor_to_major(index); @@ -1117,9 +1221,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, // as input_shape/output_shape and the dimension-0-major layout. These two // shapes are used for conversion between logical linear indices and // multi-dimensional indices. - Shape input_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + Shape input_shape_dim0_major = MakeShapeWithDescendingLayout( input_shape.element_type(), AsInt64Slice(input_shape.dimensions())); - Shape output_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + Shape output_shape_dim0_major = MakeShapeWithDescendingLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { @@ -1290,6 +1394,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); + layout->set_format(DENSE); for (size_t i = 0; i < layout->minor_to_major().size();) { if (layout->minor_to_major(i) == dim_to_delete) { layout->mutable_minor_to_major()->erase( @@ -1319,4 +1424,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return shape; } +std::ostream& operator<<(std::ostream& out, const Shape& shape) { + out << ShapeUtil::HumanString(shape); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 2ea1bd95cb571134ab1e1dda37fbc887a1fa06b2..453d4ec04726a4dd3851b8becb439bb7506e4ca9 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -133,6 +134,7 @@ class ShapeIndexView { }; std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); +std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); // Namespaced collection of (static) shape utilities. // @@ -141,7 +143,10 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); class ShapeUtil { public: // Returns the number of elements are contained within the provided shape; - // e.g. for rank 0 (scalars) the result is always 1. + // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes + // may not actually be able to store this number of elements. See + // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of + // elements that can be stored in a sparse shape. // Precondition: !IsTuple(shape) static int64 ElementsIn(const Shape& shape); @@ -162,6 +167,27 @@ class ShapeUtil { // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); + // Returns the number of bytes required to store the tuple member pointers for + // a allocation of shape. The `shape` must be a TUPLE shape, and + // `pointer_size` must be larger than zero. + static int64 ByteSizeOfTupleIndexTable(const Shape& shape, + int64 pointer_size); + + // Returns the number of bytes required for the elements in an allocation of + // `shape`, which must be an array shape. The return value does not include + // the bytes needed to store sparse indices. Dense shapes use a separate + // memory location for each element, and so for these shapes, + // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`. For dense shapes, this + // size also includes padding if present in the layout. For sparse shapes, + // `ByteSizeOf(shape) == ByteSizeOfElements(shape) + + // ByteSizeOfSparseindices(shape)`. + static int64 ByteSizeOfElements(const Shape& shape); + + // Returns the number of bytes required for the sparse indices in an + // allocation of shape. The shape must be an array shape. The return value + // does not include the bytes needed to store sparse indices. + static int64 ByteSizeOfSparseIndices(const Shape& shape); + // Returns a human-readable string that represents the given shape, with or // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]". static string HumanString(const Shape& shape); @@ -170,7 +196,7 @@ class ShapeUtil { // As above, but for program shapes, returns a string for the form: // // (param_name: f32[42x12], ...) -> f32[24x42] - static string HumanString(const ProgramShape& shape); + static string HumanString(const ProgramShape& program_shape); // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. @@ -267,14 +293,22 @@ class ShapeUtil { PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major); - // Constructs a new shape with major-first layout. - static Shape MakeShapeWithMonotonicDim0MajorLayout( + static Shape MakeShapeWithSparseLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements); + + // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). + static Shape MakeShapeWithDescendingLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions); - // Returns a new shape with major-first layout that has the same layout of - // elements with a different shape. - static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape); + // Returns a new Shape based on the given Shape with low-dimension-major + // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions + // rearranged so that it has the same in-memory layout as the given shape. + // + // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}. + static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + const Shape& shape); // As MakeShape, but the object to write to is passed in. static void PopulateShape(PrimitiveType element_type, @@ -324,7 +358,8 @@ class ShapeUtil { return shape.element_type() == OPAQUE; } - // Returns whether the shape is an array. + // Returns whether the shape is an array. Note that scalars are considered + // arrays. static bool IsArray(const Shape& shape) { return !IsTuple(shape) && !IsOpaque(shape); } @@ -351,6 +386,10 @@ class ShapeUtil { // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); + // Returns the shape of the real/imaginary components of the given complex + // shape. + static Shape ComplexComponentShape(const Shape& complex_shape); + // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. // @@ -502,8 +541,7 @@ class ShapeUtil { CHECK_EQ(Rank(shape), base.size()); CHECK_EQ(incr.size(), base.size()); CHECK_EQ(count.size(), base.size()); - const Layout& layout = shape.layout(); - const int64 rank = layout.minor_to_major_size(); + const int64 rank = LayoutUtil::MinorToMajor(shape).size(); // Allows handling R0 arrays, such that the visitor function will be called // once with the proper empty indexes. int64 n = -1; @@ -511,7 +549,7 @@ class ShapeUtil { while (n < rank && visitor_function(indexes)) { // Increments dimensions in minor to major order. for (n = 0; n < rank; ++n) { - int64 dim = layout.minor_to_major(n); + int64 dim = LayoutUtil::Minor(shape.layout(), n); indexes[dim] += incr[dim]; if (indexes[dim] < base[dim] + count[dim]) { break; @@ -529,6 +567,8 @@ class ShapeUtil { TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); }; +std::ostream& operator<<(std::ostream& out, const Shape& shape); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4bce7ca51d0534cbcad6faac12818c5f3e94b29e..81ba7afb95265398e830e26122cd0056a32daee3 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" @@ -71,7 +72,8 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) { TEST(ShapeUtilTest, ParseShapeStringR2F32) { string shape_string = "f32[123,456]"; - Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) << "expected: " << ShapeUtil::HumanString(expected) @@ -80,7 +82,8 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) { TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { string shape_string = "(f32[1572864],s8[5120,1024])"; - Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), ShapeUtil::MakeShape(S8, {5120, 1024})}); @@ -91,7 +94,8 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { string shape_string = "(f32[1],(f32[2]), f32[3])"; - Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1}), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), @@ -102,6 +106,47 @@ TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseShapeStringWithLayout) { + string shape_string = "f32[123,456]{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) { + string shape_string = "f32[123,456]dense{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { + string shape_string = "f32[123,456]sparse{10}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseInvalidShapeString) { + string shape_strings[] = { + "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", + "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", + }; + for (const string& shape_string : shape_strings) { + StatusOr result = ShapeUtil::ParseShapeString(shape_string); + ASSERT_FALSE(result.ok()) << "shape: " << shape_string; + } +} + TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); @@ -165,20 +210,6 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); } -TEST(ShapeUtilTest, EmptyLayoutEqualsMissingLayout) { - // A shape with a missing layout should be equal to a shape with an empty - // layout. - Shape scalar1 = ShapeUtil::MakeShape(F32, {}); - Shape scalar2 = ShapeUtil::MakeShape(F32, {}); - - EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2)); - - scalar1.clear_layout(); // Remove layout field. - scalar2.mutable_layout(); // Create empty layout field. - - EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2)); -} - TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); shape1.mutable_layout()->add_padded_dimensions(10); @@ -199,17 +230,17 @@ TEST(ShapeUtilTest, CompareShapesWithPaddingValueMismatch) { EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2)); } -TEST(ShapeUtilTest, ScalarUnpopulatedLayoutEqualsScalarLayout) { - Shape scalar_unpopulated = ShapeUtil::MakeShape(F32, {}); - scalar_unpopulated.clear_layout(); - ASSERT_FALSE(scalar_unpopulated.has_layout()) - << ShapeUtil::HumanStringWithLayout(scalar_unpopulated); +TEST(ShapeUtilTest, ScalarDefaultLayoutEqualsScalarEmptyMin2Maj) { + Shape scalar_default_layout = ShapeUtil::MakeShape(F32, {}); + ASSERT_TRUE(scalar_default_layout.has_layout()) + << ShapeUtil::HumanStringWithLayout(scalar_default_layout); - const Shape scalar_populated = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); - ASSERT_TRUE(scalar_populated.has_layout()) - << ShapeUtil::HumanStringWithLayout(scalar_populated); + const Shape scalar_empty_min2maj = + ShapeUtil::MakeShapeWithLayout(F32, {}, {}); + ASSERT_TRUE(scalar_empty_min2maj.has_layout()) + << ShapeUtil::HumanStringWithLayout(scalar_empty_min2maj); - EXPECT_TRUE(ShapeUtil::Equal(scalar_unpopulated, scalar_populated)); + EXPECT_TRUE(ShapeUtil::Equal(scalar_default_layout, scalar_empty_min2maj)); } TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc new file mode 100644 index 0000000000000000000000000000000000000000..31844abd89a020c87c403353374a80fb639a3244 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/sparse_index_array.h" + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +SparseIndexArray::SparseIndexArray() : rank_(0), max_indices_(0) {} + +SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, + std::vector indices) + : indices_(std::move(indices)), rank_(rank), max_indices_(max_indices) { + CHECK_GT(rank_, 0); + CHECK_EQ(indices_.size() % rank_, 0) + << "indices_.size(): " << indices_.size() << ", rank_: " << rank_; + CHECK_LT(index_count(), max_indices_); +} + +SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, + tensorflow::gtl::ArraySlice indices) + : SparseIndexArray(max_indices, rank, + std::vector(indices.begin(), indices.end())) {} + +SparseIndexArray::SparseIndexArray(int64 max_indices, + const Array2D& indices) + : SparseIndexArray(max_indices, indices.n2(), + std::vector(indices.begin(), indices.end())) {} + +int64 SparseIndexArray::index_count() const { + CHECK_GT(rank_, 0); + CHECK_EQ(indices_.size() % rank_, 0); + return indices_.size() / rank_; +} + +tensorflow::gtl::ArraySlice SparseIndexArray::At( + int64 sparse_element_number) const { + CHECK_GT(rank_, 0); + CHECK_GE(sparse_element_number, 0); + CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); + return tensorflow::gtl::ArraySlice( + indices_.data() + rank_ * sparse_element_number, rank_); +} + +tensorflow::gtl::MutableArraySlice SparseIndexArray::At( + int64 sparse_element_number) { + CHECK_GT(rank_, 0); + CHECK_GE(sparse_element_number, 0); + CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); + return tensorflow::gtl::MutableArraySlice( + indices_.data() + rank_ * sparse_element_number, rank_); +} + +void SparseIndexArray::Append(tensorflow::gtl::ArraySlice index) { + CHECK_GT(rank_, 0); + CHECK_EQ(index.size(), rank_); + indices_.insert(indices_.end(), index.begin(), index.end()); +} + +void SparseIndexArray::Clear() { indices_.clear(); } + +void SparseIndexArray::Resize(int64 num_indices) { + CHECK_GT(rank_, 0); + indices_.resize(rank_ * num_indices); +} + +bool SparseIndexArray::Validate(const Shape& shape) const { + if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) { + return false; + } + int64 num_indices = index_count(); + if (num_indices > LayoutUtil::MaxSparseElements(shape.layout())) { + return false; + } + if (num_indices < 2) { + return true; + } + tensorflow::gtl::ArraySlice last = At(0); + if (!IndexUtil::IndexInBounds(shape, last)) { + return false; + } + for (int64 n = 1; n < num_indices; ++n) { + tensorflow::gtl::ArraySlice next = At(n); + if (!IndexUtil::IndexInBounds(shape, next)) { + return false; + } + if (IndexUtil::CompareIndices(last, next) >= 0) { + return false; + } + last = next; + } + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h new file mode 100644 index 0000000000000000000000000000000000000000..903fee525520205dbd516897fe451b0fd59d3872 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility class for managing sparse array indices. + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ + +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// Encapsulates the array of indices for a sparse array. A SparseIndexArray +// contain indices for up to `max_indices` elements of a sparse array. Each +// sparse index is an array of `rank` int64 value that gives the location of a +// value within a sparse array. Note that the dimensions of the array are not +// checked (except for the rank). To avoid confusion, we refer to the position +// of an index within a SparseIndexArray as a sparse index number. +class SparseIndexArray { + public: + SparseIndexArray(); + SparseIndexArray(const SparseIndexArray&) = default; + SparseIndexArray(SparseIndexArray&&) = default; + SparseIndexArray& operator=(const SparseIndexArray&) = default; + SparseIndexArray& operator=(SparseIndexArray&&) = default; + + // Constructs a SparseIndexArray that can hold up to `max_indices` sparse + // indices, with an initial contents obtained from the given array. The rank + // is taken from the minor dimension of the array. The major dimension of the + // array must not exceed `max_indices`. + SparseIndexArray(int64 max_indices, const Array2D& indices); + + // Like above, but the array is flattened. For example, the following are + // equivalent: + // + // SparseIndexArray(10, 3, + // Array2D{ + // {0, 1, 2}, + // {3, 4, 5}, + // {6, 7, 8}, + // {9, 10, 11}, + // }) + // + // SparseIndexArray(10, 3, + // {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) + // + SparseIndexArray(int64 max_indices, int64 rank, + std::vector indices = {}); + SparseIndexArray(int64 max_indices, int64 rank, + tensorflow::gtl::ArraySlice indices); + + // Returns the number of elements represented by the indices stored in the + // array. + int64 index_count() const; + + // Returns a slice that refers to the given sparse index number. The argument + // must be in the range [0, element_count()). + tensorflow::gtl::ArraySlice At(int64 sparse_element_number) const; + tensorflow::gtl::MutableArraySlice At(int64 sparse_element_number); + + // Adds the given index at the end of the array. The new size of the + // SparseIndexArray must not exceed `max_indices`. + void Append(tensorflow::gtl::ArraySlice index); + + // Removes all indices from the array. + void Clear(); + + // Resizes the array to contain the given number of sparse indices. The new + // size must be smaller than `max_indices`. If the new size is larger than + // the old size, the value of the new indices is not specified. + void Resize(int64 num_indices); + + // Returns true iff all indices are unique and occur in sorted order, and are + // valid for the given shape. + bool Validate(const Shape& shape) const; + + int64 rank() const { return rank_; } + int64 max_indices() const { return max_indices_; } + + // Returns a pointer to the int64 array that holds the sparse indices. + tensorflow::gtl::MutableArraySlice mutable_data() { return &indices_; } + tensorflow::gtl::ArraySlice data() const { return indices_; } + + // Sorts this sparse index array along with the set of corresponding values. + // The indices and values are sorted in the lexicographic order of the + // indices, from smallest to largest. + // + // For example: + // + // std::vector v{10.0, 11.0, 12.0}; + // SparseIndexArray a(10, 3, + // {{3, 4, 5}, + // {1, 2, 3}, + // {2, 3, 4}}); + // a.SortWithValues(&v); + // // Prints "11.0, 12.0, 10.0": + // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl; + // + template + void SortWithValues(tensorflow::gtl::MutableArraySlice values); + + private: + std::vector indices_; + int64 rank_; + int64 max_indices_; +}; + +template +void SparseIndexArray::SortWithValues( + tensorflow::gtl::MutableArraySlice values) { + int64 num_elements = index_count(); + CHECK_EQ(values.size(), num_elements); + std::vector sort_order; + sort_order.reserve(num_elements); + for (int64 i = 0; i < num_elements; ++i) { + sort_order.push_back(i); + } + auto sort_order_less = [this](int64 lhs, int64 rhs) { + return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; + }; + std::sort(sort_order.begin(), sort_order.end(), sort_order_less); + + // Reorder the array elements according to sort_order. Work through the array + // and follow cycles so we can do the reorder in-place. + tensorflow::gtl::InlinedVector saved_index(rank()); + for (int64 i = 0; i < num_elements; ++i) { + // sort_order[i] == -1 indicates the element has already been copied. + if (sort_order[i] < 0) { + continue; + } else if (i == sort_order[i]) { + // The element is already in sorted order. + sort_order[i] = -1; + continue; + } + + std::copy_n(At(i).begin(), rank(), saved_index.begin()); + NativeT saved_value = values[i]; + int64 j = i; + for (;;) { + if (sort_order[j] == i) { + std::copy_n(saved_index.begin(), rank(), At(j).begin()); + values[j] = saved_value; + sort_order[j] = -1; + break; + } + + std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin()); + values[j] = values[sort_order[j]]; + + int64 k = sort_order[j]; + sort_order[j] = -1; + j = k; + } + } +} + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7377f88958dcb7daf3d3f4f0e07966fdc9294580 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array_test.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/sparse_index_array.h" + +#include + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(SparseIndexArrayTest, Sort) { + SparseIndexArray a(10, 3); + a.Append({2, 3, 4}); + a.Append({3, 4, 5}); + a.Append({1, 2, 3}); + a.Append({5, 6, 7}); + a.Append({4, 5, 6}); + a.Append({6, 7, 8}); + std::vector values = { + 12.0, 13.0, 11.0, 15.0, 14.0, 16.0, + }; + a.SortWithValues(&values); + ASSERT_EQ(a.data(), std::vector({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, + 6, 7, 6, 7, 8})); + ASSERT_EQ(values, std::vector({11.0, 12.0, 13.0, 14.0, 15.0, 16.0})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h index 5e5550563d02de99ddefbeb8ee8e1bf98afdcdbf..e51dd64e2a3dc7c359918cb08c6c94b2b4d9e91b 100644 --- a/tensorflow/compiler/xla/status_macros.h +++ b/tensorflow/compiler/xla/status_macros.h @@ -196,18 +196,8 @@ class StatusAdaptorForMacros { #define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) #define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y -#define TF_ASSIGN_OR_RETURN(...) \ - TF_STATUS_MACRO_GET_VARIADIC_IMPL(__VA_ARGS__, TF_ASSIGN_OR_RETURN_IMPL_3, \ - TF_ASSIGN_OR_RETURN_IMPL_2) \ - (__VA_ARGS__) - -#define TF_STATUS_MACRO_GET_VARIADIC_IMPL(_1, _2, _3, NAME, ...) NAME - -#define TF_ASSIGN_OR_RETURN_IMPL_2(lhs, rexpr) \ - TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr) - -#define TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr) \ - TF_ASSIGN_OR_RETURN_IMPL( \ +#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \ + TF_ASSIGN_OR_RETURN_IMPL( \ TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) #define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index 5fa2211ac66177514ac8ecabfa8791e7c8c014a2..f9d25945bc617507735fb6c4d011c39723497f69 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -32,26 +32,26 @@ namespace { class Base1 { public: virtual ~Base1() {} - int pad; + int pad_; }; class Base2 { public: virtual ~Base2() {} - int yetotherpad; + int yetotherpad_; }; class Derived : public Base1, public Base2 { public: ~Derived() override {} - int evenmorepad; + int evenmorepad_; }; class CopyNoAssign { public: - explicit CopyNoAssign(int value) : foo(value) {} - CopyNoAssign(const CopyNoAssign& other) : foo(other.foo) {} - int foo; + explicit CopyNoAssign(int value) : foo_(value) {} + CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {} + int foo_; private: const CopyNoAssign& operator=(const CopyNoAssign&); @@ -253,7 +253,7 @@ TEST(StatusOr, TestCopyCtorNonAssignable) { StatusOr original(value); StatusOr copy(original); EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.ValueOrDie().foo, copy.ValueOrDie().foo); + EXPECT_EQ(original.ValueOrDie().foo_, copy.ValueOrDie().foo_); } TEST(StatusOr, TestCopyCtorStatusOKConverting) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index addce9019b340f9489a25dbdd2437f4d71740b95..3922c779a0979c493df84431bf97c1da57717443 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", @@ -104,7 +105,9 @@ cc_library( hdrs = ["hlo_test_base.h"], deps = [ ":literal_test_util", + ":test_utils", "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -114,6 +117,10 @@ cc_library( "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -338,6 +345,23 @@ xla_test( ], ) +xla_test( + name = "xla_hlo_profile_test", + srcs = ["xla_hlo_profile_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:test", + ], +) + xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], @@ -354,6 +378,7 @@ xla_test( xla_test( name = "map_test", srcs = ["map_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -382,6 +407,7 @@ xla_test( name = "params_test", srcs = ["params_test.cc"], shard_count = 30, + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -430,6 +456,22 @@ xla_test( ], ) +xla_test( + name = "conditional_test", + srcs = ["conditional_test.cc"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], @@ -774,8 +816,6 @@ xla_test( name = "bfloat16_test", srcs = ["bfloat16_test.cc"], blacklisted_backends = [ - "cpu", - "cpu_parallel", "gpu", ], shard_count = 40, @@ -961,7 +1001,10 @@ xla_test( name = "reduce_window_test", timeout = "long", srcs = [], - tags = ["optonly"], + tags = [ + "enable_for_xla_interpreter", + "optonly", + ], xla_test_library_deps = [":reduce_window_test_library"], deps = [], ) @@ -1036,9 +1079,10 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1364,6 +1408,31 @@ xla_test( ], ) +xla_test( + name = "execution_profile_test", + srcs = ["execution_profile_test.cc"], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "execution_profile_test_with_xla_hlo_profile", + srcs = ["execution_profile_test.cc"], + args = ["--xla_hlo_profile"], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + xla_test( name = "replay_test", srcs = ["replay_test.cc"], @@ -1676,6 +1745,45 @@ xla_test( ], ) +# A demo of textual IR based test. +xla_test( + name = "sample_text_test", + srcs = ["sample_text_test.cc"], + # You can leave this empty if you want to test all supported backends. + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +# A demo of test that loads an hlo module from a file and compares results on gpu and cpu. +tf_cc_test( + name = "sample_file_test", + srcs = ["sample_file_test.cc"], + data = ["isolated_convolution.hlo"], + tags = ["requires-gpu-sm35"], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:cpu_plugin", # reference backend + "//tensorflow/compiler/xla/service:gpu_plugin", # test backend + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index c6e8b24d1211743d07878d388522feacf9c0e7f1..56fc21d019bb823f8f4631420a15fd607ef46a9a 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -1971,6 +1971,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); + auto b = builder.ConstantR1({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); + auto atan = builder.Atan2(a, b); + + ComputeAndCompareR1( + &builder, + {0.0f, 1.57079633f, 3.14159265f, -1.57079633f, 0.78539816f, -0.78539816f}, + {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); @@ -2520,9 +2532,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); ComputationBuilder builder(client_, TestName()); - std::unique_ptr a_literal = Literal::CreateR4FromArray4D(r4); - *a_literal->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout({0, 1, 2, 3}); + std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); auto a = builder.ConstantLiteral(*a_literal); auto b = builder.ConstantR1(r1); builder.Add(a, b, {1}); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 028d1251b455b82a291c236f7866e52e27d3590e..28ab9654997728fbafd6610af840e721e72cce5a 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -39,6 +39,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -46,9 +48,13 @@ limitations under the License. namespace xla { namespace { -class BatchNormalizationTest : public ClientLibraryTestBase { +class BatchNormalizationTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface { protected: BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) { + mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(GetParam()); + Array2D pz({ // z0 z1 {-1.0f, 4.1f}, // p0 @@ -56,7 +62,7 @@ class BatchNormalizationTest : public ClientLibraryTestBase { {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = *Literal::CreateR4FromArray4D(input_array_); + input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_)); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -73,7 +79,18 @@ class BatchNormalizationTest : public ClientLibraryTestBase { const ErrorSpec error_spec_{0.001, 0.001}; }; -TEST_F(BatchNormalizationTest, SubtractInZ) { +// If testing the GPU backend, run the tests twice, with and without cudnn +// batchnorm. Otherwise, just run the tests once -- the value of this flag +// doesn't matter. +#ifdef XLA_TEST_BACKEND_GPU +INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, + ::testing::Bool()); +#else +INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, + ::testing::Values(false)); +#endif + +XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { ComputationBuilder builder(client_, "subtract_in_z_one_sample"); auto x = builder.ConstantLiteral(input_literal_); auto y = builder.ConstantR1({3.14, 4.25}); @@ -89,22 +106,24 @@ TEST_F(BatchNormalizationTest, SubtractInZ) { ComputeAndCompareR4(&builder, expected, {}, error_spec_); } -TEST_F(BatchNormalizationTest, SquareTesseractElementwise) { +XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { ComputationBuilder builder(client_, "square_tesseract_elementwise"); auto x = builder.ConstantLiteral(input_literal_); builder.SquareF32(x); + using tensorflow::MathUtil; + Array4D expected(kSamples, kZ, kY, kX); Array2D expected_pz({ - {std::pow(-1.0f, 2.0f), std::pow(4.1f, 2.0f)}, - {std::pow(2.0f, 2.0f), std::pow(4.1f, 2.0f)}, - {std::pow(5.0f, 2.0f), std::pow(4.4f, 2.0f)}, + {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)}, + {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)}, + {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)}, }); expected.FillWithPZ(expected_pz); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } -TEST_F(BatchNormalizationTest, SumToZ) { +XLA_TEST_P(BatchNormalizationTest, SumToZ) { ComputationBuilder builder(client_, "sum_to_z"); auto input_activations = builder.ConstantLiteral(input_literal_); Computation add = CreateScalarAddComputation(F32, &builder); @@ -116,7 +135,7 @@ TEST_F(BatchNormalizationTest, SumToZ) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } -TEST_F(BatchNormalizationTest, SquareAndReduce) { +XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { ComputationBuilder builder(client_, "square_and_reduce"); auto input_activations = builder.ConstantLiteral(input_literal_); auto set_means = builder.ConstantR1({2.f, 4.2f}); @@ -131,7 +150,7 @@ TEST_F(BatchNormalizationTest, SquareAndReduce) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } -TEST_F(BatchNormalizationTest, VarianceToStddev) { +XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { ComputationBuilder builder(client_, "variance_to_stddev"); auto variance = builder.ConstantR1({6.f, .02f}); auto sqrt = builder.SqrtF32(variance); @@ -142,7 +161,7 @@ TEST_F(BatchNormalizationTest, VarianceToStddev) { // Compare against a forward batch normalization example in the NN spec // reference. -TEST_F(BatchNormalizationTest, SpecComparisonForward) { +XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { ComputationBuilder builder(client_, "batch_normalize_per_spec"); auto input_activations = builder.CheckShape(builder.ConstantLiteral(input_literal_), @@ -198,19 +217,227 @@ TEST_F(BatchNormalizationTest, SpecComparisonForward) { ComputeAndCompareR4(&builder, expected, {}, error_spec_); } +XLA_TEST_P(BatchNormalizationTest, BasicTraining) { + const int kFeatureIndex = 3; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); + + auto scale = builder.ConstantR1({2.0f, 3.0f}); + + auto offset = builder.ConstantR1({1.0f, 2.0f}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) + .get(), + Literal::CreateR1({4, 5}).get(), + Literal::CreateR1({5, 5}).get()}); + + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); +} + +XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); + + auto scale = builder.ConstantR1({2.0f, 3.0f}); + + auto offset = builder.ConstantR1({1.0f, 2.0f}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) + .get(), + Literal::CreateR1({4, 5}).get(), + Literal::CreateR1({5, 5}).get()}); + + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); +} + +XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { + // Use 0 dimension as feature, tests layout analyzer. + const int kFeatureIndex = 0; + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle h0; + auto operand = CreateR3Parameter(Array3D(260, 2, 2, 1.0f), + /*parameter_number=*/0, "operand", + &builder, &h0); + ComputationDataHandle h1; + auto scale = + CreateR1Parameter(std::vector(260, 1.0f), + /*parameter_number=*/1, "scale", &builder, &h1); + ComputationDataHandle h2; + auto offset = + CreateR1Parameter(std::vector(260, 1.0f), + /*parameter_number=*/2, "offset", &builder, &h2); + + auto tuple = builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/1, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) + .get(), + Literal::CreateR1(std::vector(260, 1.0f)).get(), + Literal::CreateR1(std::vector(260, 0.0f)).get()}); + + ComputeAndCompareTuple(&builder, *expected, + {operand.get(), scale.get(), offset.get()}, + ErrorSpec(0.1)); +} + +XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { + // Test the correctness of choosing a large epsilon value. + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle h0; + auto operand = CreateR3Parameter({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, + /*parameter_number=*/0, "operand", + &builder, &h0); + ComputationDataHandle h1; + auto scale = + CreateR1Parameter(std::vector(1, 1.0f), + /*parameter_number=*/1, "scale", &builder, &h1); + ComputationDataHandle h2; + auto offset = + CreateR1Parameter(std::vector(1, 0.0f), + /*parameter_number=*/2, "offset", &builder, &h2); + + // var = 125, mean = 15, epsilon = -100 + auto tuple = builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/-100, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) + .get(), + Literal::CreateR1(std::vector(1, 15.0f)).get(), + Literal::CreateR1(std::vector(1, 125.0f)).get()}); + + ComputeAndCompareTuple(&builder, *expected, + {operand.get(), scale.get(), offset.get()}, + ErrorSpec(0.1)); +} + +XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = + builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); + + auto scale = builder.ConstantR1({1.0f, 1.0f}); + + auto mean = builder.ConstantR1({0.0f, 0.0f}); + + auto var = builder.ConstantR1({1.0f, 1.0f}); + + auto grad_output = builder.ConstantR4FromArray4D( + {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); + + builder.BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) + .get(), + Literal::CreateR1({0, 0}).get(), + Literal::CreateR1({16, 20}).get()}); + + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); +} + struct BatchNormTestParam { std::vector bounds; int64 feature_index; float random_value_mean; float random_value_var; + bool use_cudnn_batchnorm; + + friend ::std::ostream& operator<<(::std::ostream& os, + const BatchNormTestParam& p) { + os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; + os << "feature_index=" << p.feature_index << ", "; + os << "random_value_mean=" << p.random_value_mean << ", "; + os << "random_value_var=" << p.random_value_var; + + // Don't print use_cudnn_batchnorm when it's false, because most backends + // never set it to true. + if (p.use_cudnn_batchnorm) { + os << ", use_cudnn_batchnorm=true"; + } + return os; + } }; // Tests to test the fused operation of BatchNorm. -class BatchNormTest : public ClientLibraryTestBase, - public ::testing::WithParamInterface { +class BatchNormTestManySizes + : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + public: + BatchNormTestManySizes() { + mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm( + GetParam().use_cudnn_batchnorm); + } }; -XLA_TEST_P(BatchNormTest, RandomizedTests) { +std::vector BuildBatchNormTestParams() { + std::vector params; + + auto add_testcase = [&](std::vector bounds, int64 feature_index, + float random_value_mean, float random_value_var) { + BatchNormTestParam p{bounds, feature_index, random_value_mean, + random_value_var, /*use_cudnn_batchnorm=*/false}; + params.push_back(p); + + // If testing the GPU backend, also run with cudnn batchnorm enabled. +#ifdef XLA_TEST_BACKEND_GPU + p.use_cudnn_batchnorm = true; + params.push_back(p); +#endif + }; + + add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f); + add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f); + + add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f); + add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f); + add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f); + add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f); + add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f); + add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f); + add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f); + add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f); + + add_testcase({24, 129, 1, 2}, 2, 10000, 10000); + add_testcase({24, 129, 1, 2}, 3, 10000, 10000); + + // Feature on low dimension to trigger relayout, check that internal logical + // to physical dimension calculation is correct after relayout. + add_testcase({1, 2, 3, 4}, 0, 100, 100); + + // Zero-sized tensor. + add_testcase({1, 0, 100, 42}, 0, 100, 100); + + return params; +} + +INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes, + ::testing::ValuesIn(BuildBatchNormTestParams())); + +XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { float epsilon = 0.001; ComputationBuilder builder(client_, TestName()); const std::vector& bounds = GetParam().bounds; @@ -286,9 +513,9 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) { auto offset_activations = builder.Parameter(2, offset_literal->shape(), "scale"); - auto expected = *Literal::MakeTuple({expected_normalized.get(), - Literal::CreateR1(mean).get(), - Literal::CreateR1(var).get()}); + auto expected = Literal::MakeTuple({expected_normalized.get(), + Literal::CreateR1(mean).get(), + Literal::CreateR1(var).get()}); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -300,13 +527,17 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) { builder.BatchNormTraining(input_activations, scale_activations, offset_activations, epsilon, feature_index); + // Run all HLO passes during this test. In particular, ClientLibraryTestBase + // disables constant folding, but we want it enabled for our zero-sized tensor + // testcase. + execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); ComputeAndCompareTuple( - &builder, expected, + &builder, *expected, {input_data.get(), scale_data.get(), offset_data.get()}, ErrorSpec(0.01, 1)); } -XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) { +XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { float epsilon = 0.001; ComputationBuilder builder(client_, TestName()); const std::vector& bounds = GetParam().bounds; @@ -402,6 +633,11 @@ XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) { offset_activations, mean_activations, variance_activations, epsilon, feature_index); + // Run all HLO passes during this test. In particular, ClientLibraryTestBase + // disables constant folding, but we want it enabled for our zero-sized tensor + // testcase. + execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); + ComputeAndCompareR4( &builder, expected, {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(), @@ -409,7 +645,7 @@ XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) { ErrorSpec(0.01, 1)); } -XLA_TEST_P(BatchNormTest, RandomizedGradTests) { +XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { float epsilon = 0.001; ComputationBuilder builder(client_, TestName()); const std::vector& bounds = GetParam().bounds; @@ -447,7 +683,11 @@ XLA_TEST_P(BatchNormTest, RandomizedGradTests) { std::vector mean(feature_bound); for (int64 i = 0; i < feature_bound; ++i) { - mean[i] = sum[i] / num_elements_per_feature; + if (num_elements_per_feature > 0) { + mean[i] = sum[i] / num_elements_per_feature; + } else { + mean[i] = 0; + } } std::vector mean_square(feature_bound); @@ -457,7 +697,11 @@ XLA_TEST_P(BatchNormTest, RandomizedGradTests) { std::vector square_mean(feature_bound); for (int64 i = 0; i < feature_bound; ++i) { - square_mean[i] = sum_squared[i] / num_elements_per_feature; + if (num_elements_per_feature > 0) { + square_mean[i] = sum_squared[i] / num_elements_per_feature; + } else { + square_mean[i] = 0; + } } std::vector var(feature_bound); @@ -535,8 +779,12 @@ XLA_TEST_P(BatchNormTest, RandomizedGradTests) { grad_activation, scale4D, [](float a, float b) { return a * b; }); grad_activation = *ReferenceUtil::MapArray4D( - grad_activation, rsqrt_var_add_epsilon, - [=](float a, float b) { return a * b / num_elements_per_feature; }); + grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) { + if (num_elements_per_feature > 0) { + return a * b / num_elements_per_feature; + } + return 0.f; + }); auto expected_grad_activation = Literal::CreateR4FromArray4D(grad_activation); @@ -571,179 +819,20 @@ XLA_TEST_P(BatchNormTest, RandomizedGradTests) { grad_output_parameter, epsilon, feature_index); auto expected = - *Literal::MakeTuple({expected_grad_activation.get(), - Literal::CreateR1(grad_scale).get(), - Literal::CreateR1(grad_offset).get()}); + Literal::MakeTuple({expected_grad_activation.get(), + Literal::CreateR1(grad_scale).get(), + Literal::CreateR1(grad_offset).get()}); + + // Run all HLO passes during this test. In particular, ClientLibraryTestBase + // disables constant folding, but we want it enabled for our zero-sized tensor + // testcase. + execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); - ComputeAndCompareTuple(&builder, expected, + ComputeAndCompareTuple(&builder, *expected, {input_data.get(), scale_data.get(), mean_data.get(), var_data.get(), grad_output_data.get()}, ErrorSpec(0.01, 1)); } -INSTANTIATE_TEST_CASE_P( - BatchNormTest_Instantiation, BatchNormTest, - ::testing::Values(BatchNormTestParam{{2, 2, 2, 2}, 0, 100.2f, 200.0f}, - BatchNormTestParam{{2, 2, 2, 2}, 3, 300.f, 400.0f}, - - BatchNormTestParam{{1, 10, 1, 1}, 0, 10.1f, 20.1f}, - BatchNormTestParam{{10, 10, 10, 10}, 1, 3.14f, 314.15f}, - BatchNormTestParam{{10, 10, 10, 10}, 2, 666.6f, 777.7f}, - BatchNormTestParam{{10, 10, 10, 10}, 1, -666.6f, 777.7f}, - BatchNormTestParam{{10, 10, 10, 10}, 2, 0.f, 777.7f}, - BatchNormTestParam{{1, 1, 10, 130}, 2, 0.f, 777.7f}, - BatchNormTestParam{{1, 1, 130, 11}, 2, 0.f, 777.7f}, - BatchNormTestParam{{1, 1, 10, 1}, 3, 888.8f, 9.9f}, - - BatchNormTestParam{{24, 129, 1, 2}, 2, 10000, 10000}, - BatchNormTestParam{{24, 129, 1, 2}, 3, 10000, 10000}, - - // Feature on low dimension to trigger relayout, test - // internal logical to physical dimension calculation - // is correct after relayout. - BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100})); - -XLA_TEST_F(BatchNormTest, BasicTraining) { - const int kFeatureIndex = 3; - ComputationBuilder builder(client_, TestName()); - - auto operand = builder.ConstantR4FromArray4D( - {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); - - auto scale = builder.ConstantR1({2.0f, 3.0f}); - - auto offset = builder.ConstantR1({1.0f, 2.0f}); - - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) - .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); - - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); -} - -XLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) { - const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); - - auto operand = builder.ConstantR4FromArray4D( - {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - - auto scale = builder.ConstantR1({2.0f, 3.0f}); - - auto offset = builder.ConstantR1({1.0f, 2.0f}); - - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) - .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); - - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); -} - -XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) { - // Use 0 dimension as feature, tests layout analyzer. - const int kFeatureIndex = 0; - ComputationBuilder builder(client_, TestName()); - - ComputationDataHandle h0; - auto operand = CreateR3Parameter(Array3D(260, 2, 2, 1.0f), - /*parameter_number=*/0, "operand", - &builder, &h0); - ComputationDataHandle h1; - auto scale = - CreateR1Parameter(std::vector(260, 1.0f), - /*parameter_number=*/1, "scale", &builder, &h1); - ComputationDataHandle h2; - auto offset = - CreateR1Parameter(std::vector(260, 1.0f), - /*parameter_number=*/2, "offset", &builder, &h2); - - auto tuple = builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/1, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) - .get(), - Literal::CreateR1(std::vector(260, 1.0f)).get(), - Literal::CreateR1(std::vector(260, 0.0f)).get()}); - - ComputeAndCompareTuple(&builder, expected, - {operand.get(), scale.get(), offset.get()}, - ErrorSpec(0.1)); -} - -XLA_TEST_F(BatchNormTest, LargeEpsilonTest) { - // Test the correctness of choosing a large epsilon value. - const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); - - ComputationDataHandle h0; - auto operand = CreateR3Parameter({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, - /*parameter_number=*/0, "operand", - &builder, &h0); - ComputationDataHandle h1; - auto scale = - CreateR1Parameter(std::vector(1, 1.0f), - /*parameter_number=*/1, "scale", &builder, &h1); - ComputationDataHandle h2; - auto offset = - CreateR1Parameter(std::vector(1, 0.0f), - /*parameter_number=*/2, "offset", &builder, &h2); - - // var = 125, mean = 15, epsilon = -100 - auto tuple = builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/-100, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) - .get(), - Literal::CreateR1(std::vector(1, 15.0f)).get(), - Literal::CreateR1(std::vector(1, 125.0f)).get()}); - - ComputeAndCompareTuple(&builder, expected, - {operand.get(), scale.get(), offset.get()}, - ErrorSpec(0.1)); -} - -XLA_TEST_F(BatchNormTest, BatchNormGradBasic) { - const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); - - auto operand = - builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); - - auto scale = builder.ConstantR1({1.0f, 1.0f}); - - auto mean = builder.ConstantR1({0.0f, 0.0f}); - - auto var = builder.ConstantR1({1.0f, 1.0f}); - - auto grad_output = builder.ConstantR4FromArray4D( - {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - - builder.BatchNormGrad(operand, scale, mean, var, grad_output, - /*epsilon=*/0.0, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) - .get(), - Literal::CreateR1({0, 0}).get(), - Literal::CreateR1({16, 20}).get()}); - - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index a1c53ef2aa95c7d2a9d46483dfda22a05ff0cf1a..e47fcad475bb176a7b4598daf2c98897eb34182b 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -61,6 +61,15 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) { error_spec_); } +XLA_TEST_F(Bfloat16Test, LogOperation) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(static_cast(4.0f)); + builder.Log(x); + + ComputeAndCompareR0(&builder, static_cast(1.387f), {}, + error_spec_); +} + XLA_TEST_F(Bfloat16Test, NegateScalarF16) { ComputationBuilder builder(client_, TestName()); builder.Neg(builder.ConstantR0(static_cast(2.1f))); @@ -88,7 +97,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { auto tuple = builder.BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = *Literal::MakeTuple( + auto expected = Literal::MakeTuple( {Literal::CreateR4( {{{{static_cast(-1.7f)}, {static_cast(-2.04f)}}, {{static_cast(0.105f)}, {static_cast(0.65f)}}}, @@ -102,7 +111,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { {static_cast(5), static_cast(5)}) .get()}); - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); } XLA_TEST_F(Bfloat16Test, BatchNormGrad) { @@ -130,7 +139,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { builder.BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = *Literal::MakeTuple( + auto expected = Literal::MakeTuple( {Literal::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, @@ -144,7 +153,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { {static_cast(16), static_cast(20)}) .get()}); - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 0294628a127c9d506e6387d0b80f3da583c5a174..6ebbf7191833ef85ee4a48cc96c0a3be38c71228 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -87,11 +87,11 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { LiteralTestUtil::ExpectNear( *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - result->tuple_literals(0), error_spec_); + LiteralView::Create(*result, {0}), error_spec_); LiteralTestUtil::ExpectNear( *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - result->tuple_literals(1), error_spec_); + LiteralView::Create(*result, {1}), error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 15bd273e9b69f9c177a4ec6b5c9f0e1dccee7fc1..7c9494f133f3db3733fc2ffa4dacfb9a71dd01d8 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -251,8 +251,17 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { + std::vector arguments(arguments_passed_in.begin(), + arguments_passed_in.end()); + if (!arguments_.empty()) { + CHECK(arguments.empty()); + for (const auto& argument : arguments_) { + arguments.push_back(argument.get()); + } + } + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { @@ -267,12 +276,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Literal* expected_ptr = &expected; std::unique_ptr converted_expected; Shape layout_shape; - if (expected.shape().element_type() == F32 && use_bfloat16_) { + if (use_bfloat16_) { converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; - layout_shape.set_element_type(BF16); + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); shape_with_layout = &layout_shape; } } @@ -295,8 +309,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout) { + tensorflow::gtl::ArraySlice arguments_passed_in, + ErrorSpec error, const Shape* shape_with_layout) { + std::vector arguments(arguments_passed_in.begin(), + arguments_passed_in.end()); + if (!arguments_.empty()) { + CHECK(arguments.empty()); + for (const auto& argument : arguments_) { + arguments.push_back(argument.get()); + } + } + TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); @@ -305,13 +328,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Literal* expected_ptr = &expected; std::unique_ptr converted_expected; Shape layout_shape; - if (expected.shape().element_type() == F32 && use_bfloat16_) { + if (use_bfloat16_) { converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); - layout_shape.set_element_type(BF16); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; - layout_shape.set_element_type(BF16); + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); shape_with_layout = &layout_shape; } } @@ -348,7 +375,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( VLOG(1) << "expected: " << expected_literal->ToString(); VLOG(1) << "actual: " << actual->ToString(); - EXPECT_EQ(expected, actual->u8s_string()); + EXPECT_EQ(expected, actual->GetR1U8AsString()); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -499,17 +526,41 @@ std::unique_ptr ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, ComputationBuilder* builder, + ComputationDataHandle* data_handle) { const Literal* param_literal = &literal; std::unique_ptr converted_literal; - if (use_bfloat16_ && literal.shape().element_type() == F32) { + if (use_bfloat16_) { converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); param_literal = converted_literal.get(); } std::unique_ptr data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, param_literal->shape(), name); return data; } +ComputationDataHandle ClientLibraryTestBase::AddParam( + const Literal& argument, ComputationBuilder* builder) { + ComputationDataHandle data_handle; + arguments_.push_back(CreateParameterAndTransferLiteral( + arguments_.size(), argument, "", builder, &data_handle)); + return data_handle; +} + +ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( + const Literal& literal, ComputationBuilder* builder) { + return builder->ConstantLiteral( + use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 1d27880fb1413adbbe691b5d12cadcd85fbe5d92..a559a653df89f3b99bd87665a7f2ccf99afa54e0 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -43,6 +43,23 @@ limitations under the License. namespace xla { +// Sets the use_bfloat16 on a container of test cases according to the values in +// use_bfloat16_params. Generates one set of test cases for each values in +// use_bfloat16_params with that value. Returns the result. +template +std::vector ExpandUseBfloat16( + tensorflow::gtl::ArraySlice use_bfloat16_params, + tensorflow::gtl::ArraySlice specs) { + std::vector expanded; + for (bool use_bfloat16 : use_bfloat16_params) { + for (const auto& spec : specs) { + expanded.push_back(spec); + expanded.back().use_bfloat16 = use_bfloat16; + } + } + return expanded; +} + // A client library test establishes an in-process XLA client connection. class ClientLibraryTestBase : public ::testing::Test { protected: @@ -194,7 +211,7 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); void ComputeAndCompareTuple( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec abs_error); + tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Convenience method for running a built computation and comparing the result // with the HloEvaluator. @@ -253,6 +270,51 @@ class ClientLibraryTestBase : public ::testing::Test { int64 parameter_number, const Literal& literal, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); + // As above, but the caller can specify the device that the literal is + // transferred to. If device_handle is nullptr, the literal will be + // transferred to the default device. + std::unique_ptr CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, ComputationBuilder* builder, + ComputationDataHandle* data_handle); + + // Creates a parameter instruction and sets the value that will be passed to + // the computation as specified. This function must be used for all parameters + // or none and no parameters must be passed when invoking the computation if + // using this mechanism. If using this mechanism, then each parameter must be + // set exactly once. The first added parameter gets index 0, then 1 and so on. + ComputationDataHandle AddParam(const Literal& argument, + ComputationBuilder* builder); + + template + ComputationDataHandle AddParam(const Array& argument, + ComputationBuilder* builder) { + return AddParam(*Literal::CreateFromArray(argument), builder); + } + + // Creates a constant instruction with the given literal. When the + // use_bfloat16 flag is set but the literal has F32 elements, the elements + // will be converted to BF16s. + ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, + ComputationBuilder* builder); + + // Creates a constant instruction with the given array. When the use_bfloat16 + // flag is set but the array has float elements, the elements will be + // converted to bfloat16s. + template + ComputationDataHandle CreateConstantFromArray(const Array& array, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + } + + // Same as CreateConstantFromArray, but for scalars. + template + ComputationDataHandle CreateConstantFromScalar(NativeT value, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateR0(value), + builder); + } + // Creates a parameter instruction that wraps a given value and then stores // into "data_handle" the global handle for that parameter. // @@ -315,6 +377,9 @@ class ClientLibraryTestBase : public ::testing::Test { bool use_bfloat16() const { return use_bfloat16_; } void set_use_bfloat16(bool value) { use_bfloat16_ = value; } + // The float type used in this test, BF16 or F32 according to use_bfloat16. + PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + Client* client_; ExecutionOptions execution_options_; @@ -344,6 +409,9 @@ class ClientLibraryTestBase : public ::testing::Test { // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; + + // Arguments to be passed to the computation when it runs. + std::vector> arguments_; }; template diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 8853ed9e5780672d4006c326291767b8b5253f56..045148cdd11da94ae4789a753efca95c6aaa1f27 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -36,7 +36,7 @@ namespace { class ClientTest : public ClientLibraryTestBase {}; -TEST_F(ClientTest, ExecuteWithLayout) { +XLA_TEST_F(ClientTest, ExecuteWithLayout) { ComputationBuilder b(client_, TestName()); std::vector> layouts = {{0, 1}, {1, 0}}; @@ -68,7 +68,7 @@ TEST_F(ClientTest, ExecuteWithLayout) { } } -TEST_F(ClientTest, ExecuteWithTupleLayout) { +XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { ComputationBuilder b(client_, TestName()); b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), @@ -90,9 +90,9 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - result->tuple_literals(0)); + LiteralView::Create(*result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - result->tuple_literals(1)); + LiteralView::Create(*result, {1})); EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); @@ -107,7 +107,8 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } -TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { +XLA_TEST_F(ClientTest, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 5226a78386824a94572d3e5cc3329677108a910a..ec2c580670cfac14ba42e8c9a836c86551af4b89 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -149,7 +149,7 @@ TEST_F(ComputeConstantTest, Param) { auto computation = b.Add(param, b.ConstantR0(1.5f)); std::vector arguments; - arguments.emplace_back(*Literal::CreateR0(42.5f)); + arguments.push_back(std::move(*Literal::CreateR0(42.5f))); EXPECT_TRUE(IsConstant(computation, &b, arguments.size())); auto value = @@ -168,7 +168,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { auto value = ComputeConstantScalar(client, computation, &b); EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) + .contains("depends on a parameter")) << value.status(); } } @@ -184,7 +184,7 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { auto value = ComputeConstantScalar(client, computation, &b); EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) + .contains("depends on a parameter")) << value.status(); } } diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0016b6cc614469d7ac9b40b740d163a7a4f32abf --- /dev/null +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -0,0 +1,553 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class ConditionalOpTest : public ClientLibraryTestBase { + protected: + Computation CreateR0ConstantComputation(float value) { + ComputationBuilder builder(client_, "Constant"); + builder.Parameter(0, empty_tuple_, "tuple"); + builder.ConstantR0(value); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0IdentityComputation() { + ComputationBuilder builder(client_, "Identity"); + builder.Parameter(0, r0f32_, "x"); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateCeilComputation(const Shape& shape) { + ComputationBuilder builder(client_, "Ceil"); + auto param = builder.Parameter(0, shape, "param"); + builder.Ceil(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0CeilComputation() { + return CreateCeilComputation(r0f32_); + } + + Computation CreateR1CeilComputation() { + return CreateCeilComputation(r1s2f32_); + } + + Computation CreateFloorComputation(const Shape& shape) { + ComputationBuilder builder(client_, "Floor"); + auto param = builder.Parameter(0, shape, "param"); + builder.Floor(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0FloorComputation() { + return CreateFloorComputation(r0f32_); + } + + Computation CreateR1FloorComputation() { + return CreateFloorComputation(r1s2f32_); + } + + Computation CreateTupleCeilComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + auto x_ceil = builder.Ceil(x); + auto y_ceil = builder.Ceil(y); + builder.Tuple({x_ceil, y_ceil}); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleCeilComputation() { + return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleCeilComputation() { + return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_); + } + + Computation CreateTupleFloorComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + auto x_floor = builder.Floor(x); + auto y_floor = builder.Floor(y); + builder.Tuple({x_floor, y_floor}); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleFloorComputation() { + return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleFloorComputation() { + return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_); + } + + Computation CreateTupleAddComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Add(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleAddComputation() { + return CreateTupleAddComputation("AddR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleAddComputation() { + return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_); + } + + Computation CreateTupleSubComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Sub(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleSubComputation() { + return CreateTupleSubComputation("SubR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleSubComputation() { + return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_); + } + + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2}); + Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); + Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})}); + Shape empty_tuple_ = ShapeUtil::MakeTupleShape({}); + ErrorSpec error_spec_{0.001}; +}; + +// Test true and false computations that do not take any parameters. +XLA_TEST_F(ConditionalOpTest, Parameters0) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operands = builder.Tuple({}); + auto true_computation = CreateR0ConstantComputation(56.0f); + auto false_computation = CreateR0ConstantComputation(12.0f); + auto result = builder.Conditional(pred, operands, true_computation, operands, + false_computation); + + ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); +} + +// Test true and false computations that take in 1 parameter. +XLA_TEST_F(ConditionalOpTest, Parameters1) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto identity = CreateR0IdentityComputation(); + auto result = + builder.Conditional(pred, operand1, identity, operand2, identity); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with two different computations in the true and false cases +// that take in different arguments. +XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), + operand2, CreateR0FloorComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with two different computations in the true and false cases +// that take in the same arguments. +XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand = builder.ConstantR0(12.6f); + auto result = builder.Conditional(pred, operand, CreateR0CeilComputation(), + operand, CreateR0FloorComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with the same computation in the true and false cases but +// take in different arguments. +XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + auto floor = CreateR0FloorComputation(); + auto result = builder.Conditional(pred, operand1, floor, operand2, floor); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with the same computation in the true and false cases that +// take in the same arguments. +XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand = builder.ConstantR0(12.6f); + auto floor = CreateR0FloorComputation(); + auto result = builder.Conditional(pred, operand, floor, operand, floor); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with different instances of the same computation in the true +// and false cases. +XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + auto result = builder.Conditional(pred, operand1, CreateR0FloorComputation(), + operand2, CreateR0FloorComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test the case when a call invokes a computation that contains a conditional. +XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); + auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); + auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); + inner_builder.Conditional(pred_cond, true_operand, CreateR0CeilComputation(), + false_operand, CreateR0FloorComputation()); + auto inner_builder_result = inner_builder.Build(); + + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + builder.Call(inner_builder_result.ConsumeValueOrDie(), + {pred, operand1, operand2}); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// true. +XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), + operands, CreateR0TupleSubComputation()); + + ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// false. +XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), + operands, CreateR0TupleSubComputation()); + + ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is true. +XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), + operands, CreateR1TupleSubComputation()); + + ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is false. +XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), + operands, CreateR1TupleSubComputation()); + + ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); +} + +// Test true and false computations that return a tuple of scalars. +XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operands = builder.Tuple( + {builder.ConstantR0(12.2f), builder.ConstantR0(25.6f)}); + builder.Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, + CreateR0TupleFloorComputation()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0(12.0f).get(), + Literal::CreateR0(25.0f).get()}), + {}, error_spec_); +} + +// Test true and false computations that return a tuple of arrays. +// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend. +XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnTupleOfArrays)) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operands = builder.Tuple({builder.ConstantR1({12.2f, 15.8f}), + builder.ConstantR1({25.6f, 29.2f})}); + builder.Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, + CreateR1TupleFloorComputation()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR1({13.0f, 16.0f}).get(), + Literal::CreateR1({26.0f, 30.0f}).get()}), + {}, error_spec_); +} + +// Test true and false computations that return a tuple of a predicate, a +// scalar, and an array. +// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend. +XLA_TEST_F(ConditionalOpTest, + DISABLED_ON_GPU(ReturnTupleofPredicateScalarArray)) { + ComputationBuilder true_builder(client_, TestName() + ".true"); + { + true_builder.Parameter(0, empty_tuple_, "tuple"); + auto true_pred = true_builder.ConstantR0(true); + auto true_scalar = true_builder.ConstantR0(12.2f); + auto true_array = true_builder.ConstantR1({12.8f, 14.6f}); + true_builder.Tuple({true_pred, true_scalar, true_array}); + } + auto true_builder_result = true_builder.Build(); + EXPECT_IS_OK(true_builder_result.status()); + + ComputationBuilder false_builder(client_, TestName() + ".false"); + { + false_builder.Parameter(0, empty_tuple_, "tuple"); + auto false_pred = false_builder.ConstantR0(false); + auto false_scalar = false_builder.ConstantR0(25.6f); + auto false_array = false_builder.ConstantR1({26.4f, 32.6f}); + false_builder.Tuple({false_pred, false_scalar, false_array}); + } + auto false_builder_result = false_builder.Build(); + EXPECT_IS_OK(false_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operands = builder.Tuple({}); + builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), + operands, false_builder_result.ConsumeValueOrDie()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0(true).get(), + Literal::CreateR0(12.2f).get(), + Literal::CreateR1({12.8f, 14.6f}).get()}), + {}, error_spec_); +} + +// Test true and false computations that return a nested tuple. +// TODO(b/71715476): Returning tuples from Conditional fails in GPU backend. +XLA_TEST_F(ConditionalOpTest, DISABLED_ON_GPU(ReturnNestedTuple)) { + ComputationBuilder true_builder(client_, TestName() + ".true"); + { + true_builder.Parameter(0, empty_tuple_, "tuple"); + auto true_constant1 = true_builder.ConstantR0(12.2f); + auto true_constant2 = true_builder.ConstantR1({12.8f, 14.6f}); + auto true_constant3 = true_builder.ConstantR1({25.4f, 29.8f}); + auto true_constant4 = true_builder.ConstantR0(35.6f); + true_builder.Tuple({true_builder.Tuple({true_constant1, true_constant2}), + true_builder.Tuple({true_constant3, true_constant4})}); + } + auto true_builder_result = true_builder.Build(); + EXPECT_IS_OK(true_builder_result.status()); + + ComputationBuilder false_builder(client_, TestName() + ".false"); + { + false_builder.Parameter(0, empty_tuple_, "tuple"); + auto false_constant1 = false_builder.ConstantR0(46.6f); + auto false_constant2 = false_builder.ConstantR1({54.4f, 58.4f}); + auto false_constant3 = false_builder.ConstantR1({62.1f, 67.4f}); + auto false_constant4 = false_builder.ConstantR0(9.3f); + false_builder.Tuple( + {false_builder.Tuple({false_constant1, false_constant2}), + false_builder.Tuple({false_constant3, false_constant4})}); + } + auto false_builder_result = false_builder.Build(); + EXPECT_IS_OK(false_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operands = builder.Tuple({}); + builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), + operands, false_builder_result.ConsumeValueOrDie()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple( + {Literal::MakeTuple({Literal::CreateR0(46.6f).get(), + Literal::CreateR1({54.4f, 58.4f}).get()}) + .get(), + Literal::MakeTuple({Literal::CreateR1({62.1f, 67.4f}).get(), + Literal::CreateR0(9.3f).get()}) + .get()}), + {}, error_spec_); +} + +// Test conditional that takes in scalar operands in the form of external +// params. +XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle pred, operand1, operand2; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); + auto operand1_param = + CreateR0Parameter(56.3f, 1, "operand1", &builder, &operand1); + auto operand2_param = + CreateR0Parameter(12.7f, 2, "operand2", &builder, &operand2); + auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), + operand2, CreateR0FloorComputation()); + + ComputeAndCompareR0( + &builder, 57.0f, + {pred_arg.get(), operand1_param.get(), operand2_param.get()}, + error_spec_); +} + +// Test conditional that takes in array operands in the form of external params. +XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle pred, operand1, operand2; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1_param = CreateR1Parameter({24.3f, 56.7f}, 1, "operand1", + &builder, &operand1); + auto operand2_param = CreateR1Parameter({10.2f, 11.6f}, 2, "operand2", + &builder, &operand2); + auto result = builder.Conditional(pred, operand1, CreateR1CeilComputation(), + operand2, CreateR1FloorComputation()); + + ComputeAndCompareR1( + &builder, {10.0f, 11.0f}, + {pred_arg.get(), operand1_param.get(), operand2_param.get()}, + error_spec_); +} + +// Test the case where one conditional is nested within another. +XLA_TEST_F(ConditionalOpTest, NestedConditionals) { + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); + auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); + auto pred_cond = inner_builder.GetTupleElement(param0, 0); + auto true_operand = inner_builder.GetTupleElement(param0, 1); + auto false_operand = inner_builder.GetTupleElement(param0, 2); + inner_builder.Conditional(pred_cond, true_operand, + CreateR0CeilComputation(), false_operand, + CreateR0FloorComputation()); + } + auto inner_builder_result = inner_builder.Build(); + EXPECT_IS_OK(inner_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred1 = builder.ConstantR0(true); + auto pred2 = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(1.1f); + auto operand2 = builder.ConstantR0(12.2f); + auto operand3 = builder.ConstantR0(43.3f); + auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); + builder.Conditional(pred1, tuple_operand, + inner_builder_result.ConsumeValueOrDie(), operand3, + CreateR0IdentityComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test a mismatch in the shape of the true operand and true computation. +XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR0TupleSubComputation()); + + auto result = builder.Build(); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("true_operand must match the shape of the " + "only parameter of true_computation")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 97bd1553664a6c0fcb097b441ec42efb4eaa9cc2..35aa3f6d696297efb7d95d826ed75a504a24529d 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -141,11 +141,12 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - Literal input_literal = *Literal::CreateR4FromArray4D(input_array); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4D(input_array); { ComputationBuilder builder(client_, TestName()); - builder.ConstantLiteral(input_literal); + builder.ConstantLiteral(*input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -165,10 +166,10 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { std::unique_ptr result = ExecuteAndTransferOrDie(&builder, {}); - LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, - result->tuple_literals(0), error_spec_); - LiteralTestUtil::ExpectR1Near({2.0, 42.0}, result->tuple_literals(1), - error_spec_); + LiteralTestUtil::ExpectR2Near( + {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near( + {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 2924c08615fa706bb19addf04bf58e1d5dd5a659..a10e17dbf34b3a6fe503f156fab496708b833c07 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -105,8 +105,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } @@ -136,8 +136,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { })); // clang-format on ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } @@ -167,8 +167,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { })); // clang-format on ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } @@ -200,8 +200,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { })); // clang-format on ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } @@ -501,10 +501,10 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, Array2D expected_result(29, 10); expected_result.Fill(0); - ComputeAndCompare( - &builder, conv, - {*Literal::CreateFromArray(param0), *Literal::CreateFromArray(param1)}, - error_spec_); + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(param0)), + std::move(*Literal::CreateFromArray(param1))}, + error_spec_); } INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation, diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index bcb85b04eefa349df1c055e010d584b85b55a4a8..ece7c3b05e7fafa299db7f9cbf50610c8204f95e 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -40,7 +40,7 @@ class CopyOpTest : public HloTestBase { void TestCopyOp(const Literal& literal) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(MakeUnique(literal))); + HloInstruction::CreateConstant(literal.CloneToUnique())); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); @@ -56,9 +56,13 @@ class CopyOpTest : public HloTestBase { tensorflow::gtl::ArraySlice permutation); }; -XLA_TEST_F(CopyOpTest, CopyR0Bool) { TestCopyOp(*Literal::CreateR0(true)); } +XLA_TEST_F(CopyOpTest, CopyR0Bool) { + TestCopyOp(*Literal::CreateR0(true)); +} -XLA_TEST_F(CopyOpTest, CopyR1S0U32) { TestCopyOp(*Literal::CreateR1({})); } +XLA_TEST_F(CopyOpTest, CopyR1S0U32) { + TestCopyOp(*Literal::CreateR1({})); +} XLA_TEST_F(CopyOpTest, CopyR1S3U32) { TestCopyOp(*Literal::CreateR1({1, 2, 3})); @@ -85,7 +89,6 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { // Copy literal to device to use as parameter. auto literal = Literal::CreateR0(42.0); Shape shape = literal->shape(); - auto constant_device_base = TransferToDevice(*literal); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); @@ -98,7 +101,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { module->AddEntryComputation(std::move(computation)); std::unique_ptr result = - ExecuteAndTransfer(std::move(module), {constant_device_base}); + ExecuteAndTransfer(std::move(module), {literal.get()}); LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); } @@ -129,7 +132,8 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { std::unique_ptr literal = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. - Layout* literal_layout = literal->mutable_shape()->mutable_layout(); + Layout* literal_layout = + literal->mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); literal_layout->mutable_minor_to_major()->SwapElements(0, 1); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 74f73a1ddc15be033e52b0b45f9961e5dc3a1ecb..2d847a66b0ae7c8f09fa0cb181a4c84ea99be5b1 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -128,5 +129,19 @@ XLA_TEST_F(CustomCallTest, Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result); } +class CustomCallClientAPITest : public ClientLibraryTestBase {}; + +// When using the client API, CustomCall targets can't begin with '$' -- these +// are reserved for internal use. +XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) { + ComputationBuilder builder(client_, TestName()); + auto call = builder.CustomCall("$illegal", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {1})); + + StatusOr> result = + Execute(&builder, /*arguments=*/{}); + EXPECT_FALSE(result.ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index bfb04fd9f9bf6887c4462cb00fee00250517f5c4..cc683701e6305510d202721fe645310f1009081c 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -51,8 +51,6 @@ class DotOperationTest : public ClientLibraryTestBase { template void TestNonsquareMatrixDot(bool lhs_row_major = false, bool rhs_row_major = false); - void TestMatrixDot(int M, int K, int N, bool lhs_row_major = false, - bool rhs_row_major = false); }; XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { @@ -199,158 +197,182 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major, - bool rhs_row_major) { - std::unique_ptr> lhs_data = - MakeLinspaceArray2D(0.0, 1.0, M, K); - std::unique_ptr lhs_lit = Literal::CreateR2FromArray2DWithLayout( - *lhs_data, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))); - auto lhs_handle = client_->TransferToServer(*lhs_lit).ConsumeValueOrDie(); +struct DotTestParam { + int m; + int k; + int n; + bool dot_lhs_row_major; + bool dot_rhs_row_major; + bool has_addend; + bool addend_row_major; +}; - std::unique_ptr> rhs_data = - MakeLinspaceArray2D(0.0, 1.0, K, N); - std::unique_ptr rhs_lit = Literal::CreateR2FromArray2DWithLayout( - *rhs_data, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))); - auto rhs_handle = client_->TransferToServer(*rhs_lit).ConsumeValueOrDie(); +string PrintDotTestParam( + const ::testing::TestParamInfo& test_param) { + const DotTestParam& param = test_param.param; + if (param.has_addend) { + return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, + "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F", + param.addend_row_major ? "T" : "F"); + } else { + return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, + "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F"); + } +} + +class ParametricDotTest : public DotOperationTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(ParametricDotTest, TestF32) { + DotTestParam param = GetParam(); + + std::unique_ptr> dot_lhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); + std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); + std::unique_ptr dot_lhs_handle = + client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); + + std::unique_ptr> dot_rhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); + std::unique_ptr dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout( + *dot_rhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_rhs_row_major))); + std::unique_ptr dot_rhs_handle = + client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); + + std::unique_ptr> addend_data; + std::unique_ptr addend_lit; + std::unique_ptr addend_handle; + + if (param.has_addend) { + addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); + addend_lit = Literal::CreateR2FromArray2DWithLayout( + *addend_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.addend_row_major))); + addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); + } ComputationBuilder builder(client_, TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {M, K}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {K, N}), "rhs")); - - std::unique_ptr> expected = - ReferenceUtil::MatmulArray2D(*lhs_data, *rhs_data); - - ComputeAndCompareR2(&builder, *expected, - {lhs_handle.get(), rhs_handle.get()}, - ErrorSpec(0.3, 3e-3)); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTF) { - TestMatrixDot(12, 117, 7, true, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFT) { - TestMatrixDot(12, 117, 7, false, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTT) { - TestMatrixDot(12, 117, 7, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFF) { - TestMatrixDot(12, 117, 7, false, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTT) { - TestMatrixDot(270, 270, 520, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTF) { - TestMatrixDot(270, 270, 520, true, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFT) { - TestMatrixDot(270, 270, 520, false, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFF) { - TestMatrixDot(270, 270, 520, false, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTT) { - TestMatrixDot(269, 3, 520, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTF) { - TestMatrixDot(260, 3, 520, true, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFT) { - TestMatrixDot(260, 3, 520, false, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) { - TestMatrixDot(260, 3, 520, false, false); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) { - TestMatrixDot(1, 8, 8, true, true); -} + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}), + "dot_lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}), + "dot_rhs")); + + if (param.has_addend) { + result = builder.Add( + result, + builder.Parameter( + 2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend")); + } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) { - TestMatrixDot(1, 130, 8, true, true); -} + std::unique_ptr> expected; + if (param.has_addend) { + expected = ReferenceUtil::ApplyElementwise2D( + std::plus(), + *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), + *addend_data); + } else { + expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data); + } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) { - TestMatrixDot(1, 8, 130, true, true); -} + std::vector args = {dot_lhs_handle.get(), dot_rhs_handle.get()}; + if (param.has_addend) { + args.push_back(addend_handle.get()); + } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) { - TestMatrixDot(1, 290, 130, true, true); + ComputeAndCompareR2(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) { - TestMatrixDot(2, 1, 1, true, true); -} +std::vector CreateDotTestParameters() { + std::vector params; -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) { - TestMatrixDot(8, 8, 1, true, true); -} + auto add_matrix_matrix_dot_test = [&](int m, int k, int n) { + for (bool lhs_row_major : {true, false}) { + for (bool rhs_row_major : {true, false}) { + params.push_back({/*m=*/m, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/false, /*addend_row_major=*/true}); + } + } + }; + + auto add_matrix_vector_dot_test = [&](int k, int n) { + for (bool has_addend : {false, true}) { + params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, + /*has_addend=*/has_addend, /*addend_row_major=*/true}); + if (n != 1) { + params.push_back( + {/*m=*/n, /*k=*/k, /*n=*/1, + /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, + /*has_addend=*/has_addend, /*addend_row_major=*/true}); + } + } + }; -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) { - TestMatrixDot(16, 1, 1, true, true); -} + add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); + add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); + add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) { - TestMatrixDot(16, 3, 1, true, true); -} + add_matrix_vector_dot_test(/*k=*/8, /*n=*/8); + add_matrix_vector_dot_test(/*k=*/130, /*n=*/8); + add_matrix_vector_dot_test(/*k=*/8, /*n=*/130); + add_matrix_vector_dot_test(/*k=*/290, /*n=*/130); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/1); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/16); + add_matrix_vector_dot_test(/*k=*/3, /*n=*/16); + add_matrix_vector_dot_test(/*k=*/3, /*n=*/3); + add_matrix_vector_dot_test(/*k=*/29, /*n=*/29); + add_matrix_vector_dot_test(/*k=*/8, /*n=*/2); + add_matrix_vector_dot_test(/*k=*/2, /*n=*/8); + add_matrix_vector_dot_test(/*k=*/259, /*n=*/258); -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) { - TestMatrixDot(3, 3, 1, true, true); + return params; } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) { - TestMatrixDot(29, 29, 1, true, true); -} +INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, + ::testing::ValuesIn(CreateDotTestParameters()), + PrintDotTestParam); -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) { - TestMatrixDot(1, 8, 2, true, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { + TestSquareMatrixDot(false, false); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) { - TestMatrixDot(1, 2, 8, true, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { + TestSquareMatrixDot(false, true); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) { - TestMatrixDot(259, 258, 1, true, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { + TestSquareMatrixDot(true, false); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) { - TestMatrixDot(259, 258, 1, false, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { + TestSquareMatrixDot(true, true); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = false; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { + TestSquareMatrixDot(false, false); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { - TestSquareMatrixDot(false, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { + TestSquareMatrixDot(false, true); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { - TestSquareMatrixDot(true, false); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { + TestSquareMatrixDot(true, false); } -TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = true; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { + TestSquareMatrixDot(true, true); } XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { @@ -561,5 +583,95 @@ TEST_F(DotOperationTest, TransposeFolding) { } } +TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { + auto prim_type = primitive_util::NativeToPrimitiveType(); + + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + + ComputationBuilder builder(client_, TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + "rhs_arg_0"); + auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), + "rhs_arg_1"); + auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), + "rhs_arg_2"); + auto result = builder.Dot( + lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); + + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0, 2.0}})); + + TF_ASSERT_OK_AND_ASSIGN( + auto arg_0_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_1_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_2_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); + + Array2D expected({{53.0, 74.0}, {45.0, 66.0}}); + ComputeAndCompareR2( + &builder, expected, + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); +} + +TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) { + auto prim_type = primitive_util::NativeToPrimitiveType(); + + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + + ComputationBuilder builder(client_, TestName()); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + "lhs_arg_0"); + auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}), + "lhs_arg_1"); + auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}), + "lhs_arg_2"); + auto result = builder.Dot( + builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); + + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0}, {2.0}})); + + TF_ASSERT_OK_AND_ASSIGN( + auto arg_0_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_1_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_2_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); + + Array2D expected({{38.0, 36.0}, {93.0, 91.0}}); + ComputeAndCompareR2( + &builder, expected, + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 8baaf39e3cf8fa7f6fa4a0224c1297f82e0d92aa..ae3f887240d0ccffcc9c51a2c409de457a94f967 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -51,12 +51,16 @@ class DynamicSliceTest : public ClientLibraryTestBase { RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {3}, {2, 3, 4}); // Slice at dimension boundaries. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {5}, {3}, {5, 6, 7}); - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); // Zero element slice. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {0}, {}); } + template + void TestR1Wrap() { + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); + } + template void TestR2() { // Slice at dimension start. @@ -68,15 +72,19 @@ class DynamicSliceTest : public ClientLibraryTestBase { // Slice at dimension boundaries. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1}, {{5}, {8}}); - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, - {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); // Zero element slice: 2x0. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 0}, {{}, {}}); // Zero element slice: 0x2. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {0, 2}, - Array2D(0, 2)); + Array2D(0, 2)); + } + + template + void TestR2Wrap() { + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, + {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); } template @@ -97,85 +105,119 @@ class DynamicSliceTest : public ClientLibraryTestBase { {{7, 8}, {9, 10}, {11, 12}}}, {0, 1, 1}, {2, 2, 1}, {{{4}, {6}}, {{10}, {12}}}); + // clang-format on + } + template + void TestR3Wrap() { // Slice at dimension boundaries, but with sizes that cause indices to wrap. RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); - - // clang-format on } template - void RunR1(tensorflow::gtl::ArraySlice input_values, + void RunR1(tensorflow::gtl::ArraySlice input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - tensorflow::gtl::ArraySlice expected_values) { + tensorflow::gtl::ArraySlice expected_values_int) { + // bfloat16 has explicit constructors, so it does not implicitly convert the + // way built-in types do, which is why we can't take the parameter as an + // ArraySlice. We also can't convert it to a vector, because + // vector is special so that it cannot be an ArraySlice, which + // is what the code below wants. So instead we do this. + Literal input_values = + std::move(*Literal::CreateR1(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR1(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR1(input_values); + auto input = builder.ConstantLiteral(input_values); builder.DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareR1(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } template - void RunR2(const Array2D& input_values, + void RunR2(const Array2D& input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - const Array2D& expected_values) { + const Array2D& expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR2FromArray2D(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR2FromArray2D(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR2FromArray2D(input_values); + auto input = builder.ConstantLiteral(input_values); builder.DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareR2(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } template - void RunR3(const Array3D& input_values, + void RunR3(const Array3D& input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - const Array3D& expected_values) { + const Array3D& expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR3FromArray3D(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR3FromArray3D(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR3FromArray3D(input_values); + auto input = builder.ConstantLiteral(input_values); builder.DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. - ComputeAndCompareR3(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } }; +XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } - +XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } - XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } - +XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } - XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } - +XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } - XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { @@ -213,7 +255,7 @@ XLA_TEST_F(DynamicSliceTest, Int32R2Pred) { // Zero element slice: 0x2. RunR2( {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0}, - {0, 2}, Array2D(0, 2)); + {0, 2}, Array2D(0, 2)); } XLA_TEST_F(DynamicSliceTest, Int32R3Pred) { @@ -300,107 +342,154 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values, - tensorflow::gtl::ArraySlice update_values, + void RunR1(tensorflow::gtl::ArraySlice input_values_int, + tensorflow::gtl::ArraySlice update_values_int, const std::vector slice_starts, - tensorflow::gtl::ArraySlice expected_values) { + tensorflow::gtl::ArraySlice expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR1(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal update_values = + std::move(*Literal::CreateR1(update_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR1(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR1(input_values); - auto update = builder.ConstantR1(update_values); + auto input = builder.ConstantLiteral(input_values); + auto update = builder.ConstantLiteral(update_values); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareR1(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } template - void RunR2(const Array2D& input_values, - const Array2D& update_values, + void RunR2(const Array2D& input_values_int, + const Array2D& update_values_int, const std::vector slice_starts, - const Array2D& expected_values) { + const Array2D& expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR2FromArray2D(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal update_values = + std::move(*Literal::CreateR2FromArray2D(update_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR2FromArray2D(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR2FromArray2D(input_values); - auto update = builder.ConstantR2FromArray2D(update_values); + auto input = builder.ConstantLiteral(input_values); + auto update = builder.ConstantLiteral(update_values); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareR2(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } template - void RunR3(const Array3D& input_values, - const Array3D& update_values, + void RunR3(const Array3D& input_values_int, + const Array3D& update_values_int, const std::vector slice_starts, - const Array3D& expected_values) { + const Array3D& expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR3FromArray3D(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal update_values = + std::move(*Literal::CreateR3FromArray3D(update_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR3FromArray3D(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR3FromArray3D(input_values); - auto update = builder.ConstantR3FromArray3D(update_values); + auto input = builder.ConstantLiteral(input_values); + auto update = builder.ConstantLiteral(update_values); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareR3(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } + template void RunR3Contiguous(std::vector operand_shape, int32 index, int32 size) { +#ifdef XLA_TEST_BACKEND_CPU_PARALLEL + // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. + if (std::is_same::value) { + return; + } +#endif + const int32 kSeq = operand_shape[0]; const int32 kBatch = operand_shape[1]; const int32 kDim = operand_shape[2]; - Array3D input_values(kSeq, kBatch, kDim); - Array3D update_values(size, kBatch, kDim); - Array3D expected_values(kSeq, kBatch, kDim); + Array3D input_values(kSeq, kBatch, kDim); + Array3D update_values(size, kBatch, kDim); + Array3D expected_values(kSeq, kBatch, kDim); - input_values.FillIota(0); - float val = 1000; - update_values.FillIota(val); + input_values.FillIota(static_cast(0)); + T value = static_cast(10); + update_values.FillIota(static_cast(value)); // TODO(b/34128753) Expected values may vary depending on backend when // the update wraps. According to documentation, the results are technically // implementation specific where the update is out of bounds, and hence // we don't really know what to pass into ComputeAndCompareR3. - expected_values.FillIota(0); + expected_values.FillIota(static_cast(0)); for (int i = 0; i < size; i++) { for (int j = 0; j < kBatch; j++) { for (int k = 0; k < kDim; k++) { - expected_values((index + i) % kSeq, j, k) = val++; + expected_values((index + i) % kSeq, j, k) = value++; } } } if (VLOG_IS_ON(1)) { - DumpArray("input", input_values); - DumpArray("update", update_values); - DumpArray("expected", expected_values); + DumpArray("input", input_values); + DumpArray("update", update_values); + DumpArray("expected", expected_values); } // Build dynamic slice computation. ComputationBuilder builder(client_, TestName()); // Initialize and transfer input parameter. ComputationDataHandle input; - std::unique_ptr input_data = CreateR3Parameter( - input_values, 0, "input_values", &builder, &input); + std::unique_ptr input_data = + CreateR3Parameter(input_values, 0, "input_values", &builder, &input); // Initialize and transfer update parameter. ComputationDataHandle update; - std::unique_ptr update_data = CreateR3Parameter( + std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); auto starts = builder.ConstantR1({index, 0, 0}); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareR3(&builder, expected_values, - {input_data.get(), update_data.get()}, - ErrorSpec(0.000001)); + ComputeAndCompareR3(&builder, expected_values, + {input_data.get(), update_data.get()}, + ErrorSpec(0.000001)); } template @@ -411,28 +500,35 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } }; +// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. +XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) { + TestR1(); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } - XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } - XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } +// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. +XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R2BF16)) { + TestR2(); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } - XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } - XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } +// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. +XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R3BF16)) { + TestR3(); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } - XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } - XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32WrapBF16)) { + TestWrap(); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } - XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } - XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { @@ -498,36 +594,42 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { // Single element, no wrap. std::vector operand_shape({4, 5, 2}); - RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { // Multiple element, no wrap. std::vector operand_shape({4, 5, 2}); - RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { // Multiple element, wrapping. std::vector operand_shape({4, 5, 2}); - RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); + RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); + RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) { // Multiple element, update size larger than operand. std::vector operand_shape({4, 5, 2}); - RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); + RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); + RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) { std::vector operand_shape({3, 123, 247}); - RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } // TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error. XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) { std::vector operand_shape({32, 128, 1024}); - RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); + RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); + RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); } void BM_DynamicSlice(int num_iters) { @@ -559,20 +661,20 @@ void BM_DynamicSlice(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. - auto shape_size_fn = [client](const Shape& shape) { - return client->backend().transfer_manager()->GetByteSizeRequirement(shape); - }; - auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0, - shape_size_fn) + auto buffer = client->backend() + .transfer_manager() + ->AllocateScopedShapedBuffer( + start_indices_shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, - buffer->mutable_buffer({}))); + executors[device_ordinal], *start_indices_literal, *buffer)); std::unique_ptr executable = - client->Compile(computation, {&buffer->shape()}, ExecutableBuildOptions()) + client + ->Compile(computation, {&buffer->on_host_shape()}, + ExecutableBuildOptions()) .ConsumeValueOrDie(); // Run some warm-up executions. diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..644cbbf40f296eb2a574ae568b4f32aa3d0bd12f --- /dev/null +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class ExecutionProfileTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(ExecutionProfileTest, + DISABLED_ON_CPU_PARALLEL(ExecuteWithExecutionProfile)) { + Shape shape = ShapeUtil::MakeShape(F32, {256, 256}); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr input, + client_->TransferToServer( + *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + + ComputationBuilder b(client_, TestName() + ".add"); + b.Dot(b.Parameter(0, shape, "param_0"), b.Parameter(1, shape, "param_1")); + TF_ASSERT_OK_AND_ASSIGN(Computation dot_product, b.Build()); + + ExecutionProfile execution_profile; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr data, + client_->Execute(dot_product, {input.get(), input.get()}, + &execution_options_, &execution_profile)); + + VLOG(3) << "execution_profile.compute_cycle_count() = " + << execution_profile.compute_cycle_count(); + VLOG(3) << "execution_profile.compute_and_transfer_time_ns() = " + << execution_profile.compute_and_transfer_time_ns(); + VLOG(3) << "execution_profile.compute_time_ns() = " + << execution_profile.compute_time_ns(); + + bool hlo_profiling_enabled = + execution_options_.debug_options().xla_hlo_profile(); + + // If HLO profiling is enabled we always expect cycle count to be populated. + // If HLO profiling is disabled then depending on the backend the cycle count + // may or may not be populated. + if (hlo_profiling_enabled) { + EXPECT_GT(execution_profile.compute_cycle_count(), 0); + } + + EXPECT_GT(execution_profile.compute_and_transfer_time_ns(), 0); + EXPECT_GT(execution_profile.compute_time_ns(), 0); + + TF_ASSERT_OK_AND_ASSIGN(auto computed, client_->Transfer(*data, &shape)); + (void)computed; +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 2686afccc216095345dbb7b43e916fbbe7c8ea39..a292eab1d198fbf69c6dc81c780487ea46756f72 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -816,7 +816,8 @@ void BM_ParallelFusion(int num_iters) { std::unique_ptr executable = client ->Compile(computation, - {&buffer0->shape(), &buffer1->shape(), &buffer2->shape()}, + {&buffer0->on_host_shape(), &buffer1->on_host_shape(), + &buffer2->on_host_shape()}, ExecutableBuildOptions()) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d73c05ff92578209143e0679558848160cae99bd..7c1a993b478a0e0878e85c0e4192da053e33619f 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -15,13 +15,22 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include #include #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -30,44 +39,235 @@ namespace se = ::perftools::gputools; namespace xla { +namespace { + +using tensorflow::StringPiece; +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::optional; + +constexpr char kInterpreter[] = "interpreter"; + +// Helper functions to get test and reference platforms. +se::Platform* GetReferencePlatform() { + auto result = PlatformUtil::GetPlatform(kInterpreter); + TF_CHECK_OK(result.status()) << "could not get interpreter platform"; + return result.ValueOrDie(); +} + +se::Platform* GetTestPlatform() { + auto result = PlatformUtil::GetDefaultPlatform(); + TF_CHECK_OK(result.status()) << "could not get test platform"; + return result.ValueOrDie(); +} + +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { + if (lhs.parameters_size() != rhs.parameters_size()) { + return false; + } + for (int i = 0; i < lhs.parameters_size(); i++) { + if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) { + return false; + } + } + return ShapeUtil::Equal(lhs.result(), rhs.result()); +} + +ProgramShape GetProgramShapeWithLayout(const HloModule& module) { + ProgramShape program_shape; + const auto* entry = module.entry_computation(); + for (const auto* param : entry->parameter_instructions()) { + *program_shape.add_parameters() = param->shape(); + *program_shape.add_parameter_names() = param->name(); + } + *program_shape.mutable_result() = entry->root_instruction()->shape(); + return program_shape; +} + +} // namespace + +HloTestBase::HloTestBase() + : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} + +HloTestBase::HloTestBase(se::Platform* test_platform, + se::Platform* reference_platform) + : test_runner_(test_platform), reference_runner_(reference_platform) { + hlo_verifier_ = MakeUnique(); +} + /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} +/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); + return debug_options; +} - config.set_debug_options(debug_options); - - return MakeUnique(TestName(), VersionedComputationHandle(), - config); +StatusOr> HloTestBase::Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments); } -StatusOr HloTestBase::Execute( +std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape) { - return runner_.Execute(std::move(module), arguments, result_shape); + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } -se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - return runner_.TransferToDevice(literal).ValueOrDie(); +StatusOr> HloTestBase::MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor) { + std::unique_ptr reference_module = test_module.Clone(); + const auto& program_shape = GetProgramShapeWithLayout(test_module); + + if (reference_preprocessor != nullptr) { + reference_preprocessor(reference_module.get()); + if (!ProgramShapesEqual(program_shape, + GetProgramShapeWithLayout(*reference_module))) { + return InvalidArgument( + "reference preprocessor must not modify the program shape"); + } + } + TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(), + reference_module.get())); + return std::move(reference_module); } -std::unique_ptr HloTestBase::TransferFromDevice( - const Shape& shape, se::DeviceMemoryBase device_base) { - return runner_.TransferFromDevice(shape, device_base).ValueOrDie(); +template +StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor) { + static_assert( + std::is_same::value || + std::is_same, LiteralPtr>::value, + "The LiteralPtr type only accepts Literal* or std::unique_ptr."); + TF_RETURN_IF_ERROR( + VerifyHloModule(*test_runner_.backend().platform(), module.get())); + TF_ASSIGN_OR_RETURN(auto reference_module, + MakeReferenceModule(*module, reference_preprocessor)); + + // Execute on two backends. + TF_ASSIGN_OR_RETURN( + auto test, + test_runner_.Execute(std::move(module), arguments, run_hlo_passes)); + TF_ASSIGN_OR_RETURN(auto reference, + reference_runner_.Execute(std::move(reference_module), + arguments, run_hlo_passes)); + return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + error); } -std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { - return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie(); +template +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/true, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +template +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/false, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompare>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompareNoHloPasses>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); } -Backend& HloTestBase::backend() { return runner_.backend(); } +Backend& HloTestBase::backend() { return test_runner_.backend(); } /* static */ string HloTestBase::TestName() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 7f068dce36be3546298de2f06bf6d33446d07ca2..4aea9fc9fd027231106e529eb16bcd43f23fbe1c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -24,52 +24,150 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" namespace xla { -// A base class for tests which build and run HLO code. This is a lower level of -// abstraction than using the client interface and enables, for one, explicitly -// building a graph of HLO instructions to run. +// A base class for tests which build and/or run HLO code. The class includes +// support for running an HLO module on two platforms and compare the results. +// This is a lower level of abstraction than using the client interface and +// enables, for one, explicitly building a graph of HLO instructions to run. +// +// This can also be used to write text/file-based test cases. Note that the test +// target is responsible for linking the needed backends. A covenient way to do +// this is to make it an xla_test: it will generate test targets linking with +// the respective backends, which will be used as the test backend; the +// interpreter backend is already linked with hlo_test_base so it will be the +// default reference backend. For example, if you want to compare both cpu vs. +// interpreter, and gpu vs. interpreter, you can: +// +// xla_test ( +// name = "sample_text_test", +// srcs = ["sample_text_test.cc"], +// backends = [ +// "cpu", +// "gpu", +// ], +// deps = [ +// "//third_party/tensorflow/compiler/xla/tests:hlo_test_base", +// ... +// ], +// ) +// +// For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { protected: - HloTestBase() {} + // This uses the interpreter backend as the reference backend and + // automatically finds another supported backend as the test backend. If the + // interpreter is the only supported backend, it will be both the test backend + // and the reference backend. + HloTestBase(); + + // If your test doesn't use interpreter as the reference backend, you can use + // this constructor. Note that your test target is responsible for linking in + // both needed backends. + HloTestBase(::perftools::gputools::Platform* test_platform, + ::perftools::gputools::Platform* reference_platform); ~HloTestBase() override {} // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. It's recommended to use this method to - // create all HloModules for tests. + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. static std::unique_ptr CreateNewModule(); - // Executes the given module and returns a global data handle. - StatusOr Execute( + // Populates debug options from command-line flags and adjusts the options for + // testing. It is recommended to use this when you need to pass in + // DebugOptions, e.g. when creating a module from a string or a file. + static DebugOptions GetDebugOptionsForTest(); + + // Executes the given module and return the result as a Literal. + StatusOr> Execute( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape); + tensorflow::gtl::ArraySlice arguments); - // Transfers the given literal to the device and returns the data handle. - perftools::gputools::DeviceMemoryBase TransferToDevice( - const Literal& literal); + std::unique_ptr ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments); + + // Executes the given hlo module on two backends and compares results. + // + // 'arguments': the input of the hlo module. The LiteralPtr type accepts + // Literal* or std::unique_ptr. + // + // 'error': if has value, expects the results to be near (within the error + // bound). Otherwise, expects the results to be equal. + // + // 'reference_preprocessor': the module should be ready to run on the test + // backend, but it might need to be tailored so that it is able to run on the + // reference backend. Note that the program shape of the module must not be + // modified. + template + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Same as above, except that the module will be executed without Hlo + // optimization. + template + ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; - // Transfers the array referred to by the given handle from the device and - // returns as a Literal. - std::unique_ptr TransferFromDevice( - const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); + // Executes an hlo module with fake inputs and compares the results. + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; - // Executes the given module and return the result as a Literal. - std::unique_ptr ExecuteAndTransfer( + // Same as above, except that the module will be executed without Hlo + // optimization. + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments); + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Convenient wrappers for executing and comparing an hlo module with fake + // input. Module can be passed in directly, or parsed from an hlo_string, + // or loaded from a file. + ::testing::AssertionResult RunAndCompare( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPasses( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; // Convenience method to force the layout of a given parameter in a module. // The layout of parameter number 'param_no' in the 'module' is set to @@ -99,14 +197,38 @@ class HloTestBase : public ::testing::Test { ->Clear(); } + // Return an HLO verifier constructed for the test backend. + HloVerifier& verifier() const { return *hlo_verifier_; } + static string TestName(); - // Returns the backend owned by the HloRunner. + // Returns the backend owned by the test runner. Backend& backend(); - HloRunner runner_; + HloRunner test_runner_; + HloRunner reference_runner_; + + std::unique_ptr hlo_verifier_; ErrorSpec error_spec_{0.0001}; + + private: + // Given the test module, makes a reference module that is ready to run on the + // reference platform. This assumes that the given module is ready to run on + // the test platform. + StatusOr> MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor); + + // Runs the module on two platforms with or without running hlo passes and + // compares the results. Returns whether the results are near or equal. If any + // error happens before the results are computed, returns the error status. + template + StatusOr<::testing::AssertionResult> RunAndCompareInternal( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 31060b9e80fcd50aefdedca27c70ec8a9b8be743..506091ddd8d1d8e6519525bb7031f4e8b296b5fb 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -23,15 +23,8 @@ limitations under the License. namespace xla { -/*static*/ int64 HloVerifiedTestBase::DefaultShapeSize(const Shape& shape) { - constexpr int64 kPointerSize = sizeof(void*); - if (ShapeUtil::IsOpaque(shape)) { - return kPointerSize; - } - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} - -HloVerifiedTestBase::HloVerifiedTestBase() : shape_size_fn_(DefaultShapeSize) {} +HloVerifiedTestBase::HloVerifiedTestBase() + : shape_verifier_(MakeUnique()) {} HloVerifiedTestBase::~HloVerifiedTestBase() { // We can't call the ASSERT or EXPECT test macros in destructors, so we @@ -47,7 +40,7 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - HloVerifier verifier(shape_size_fn_); + HloVerifier verifier; xla::StatusOr mutated = verifier.Run(module_.get()); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index b3d6b5af3b46f932707abf309669d23c327d1334..492688bf7d682cf991cb8c09399492a0437f651b 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -28,14 +28,13 @@ namespace xla { // A base class for HLO tests that stores a default HloModule, and automatically // performs verification on that module on tear-down. class HloVerifiedTestBase : public HloTestBase { - public: - // Returns the size in bytes of the given shape, using a default pointer size. - static int64 DefaultShapeSize(const Shape& shape); - protected: HloVerifiedTestBase(); ~HloVerifiedTestBase() override; + // Constructs a default shape verifier. + std::unique_ptr MakeShapeVerifier(); + // Performs verification on the default HloModule returned by module(). // Automatically called by the testing framework for each test. // @@ -47,14 +46,14 @@ class HloVerifiedTestBase : public HloTestBase { HloModule& module(); // Sets the shape-size function used during hlo verification. If this isn't - // called, DefaultShapeSize is used instead. - void SetShapeSizeFn(std::function shape_size_fn) { - shape_size_fn_ = std::move(shape_size_fn); + // called, a default ShapeVerifier is used instead. + void SetShapeVerifier(std::unique_ptr shape_verifier) { + shape_verifier_ = std::move(shape_verifier); } private: std::unique_ptr module_; // Lazily populated. Access via module(). - std::function shape_size_fn_; + std::unique_ptr shape_verifier_; bool tear_down_called_ = false; }; diff --git a/tensorflow/compiler/xla/tests/isolated_convolution.hlo b/tensorflow/compiler/xla/tests/isolated_convolution.hlo new file mode 100644 index 0000000000000000000000000000000000000000..9452780930efbb1ecc13b35cd4ab53678d36c37f --- /dev/null +++ b/tensorflow/compiler/xla/tests/isolated_convolution.hlo @@ -0,0 +1,8 @@ +HloModule convolution.167: + +ENTRY %convolution.167 (parameter.0: f32[16,28,28,128], parameter.1: f32[3,3,128,128]) -> f32[16,28,28,128] { + %parameter.0 = f32[16,28,28,128]{3,0,2,1} parameter(0) + %parameter.1 = f32[3,3,128,128]{3,2,1,0} parameter(1) + ROOT %convolution.167 = f32[16,28,28,128]{3,0,2,1} convolution(f32[16,28,28,128]{3,0,2,1} %parameter.0, f32[3,3,128,128]{3,2,1,0} %parameter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01oi->b01f +} + diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 6aa27e5470d22a8c6698389a720a38e9ea254617..e5b96c51ce303819e33d67f5f383c119d313bae1 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -57,7 +57,8 @@ namespace xla { } for (int i = 0; i < expected.tuple_shapes_size(); ++i) { ::testing::AssertionResult result = - EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)) + << "mismatch in tuple index " << i; if (!result) { return result; } @@ -100,36 +101,57 @@ namespace xla { ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); } +namespace { + +// Return a literal with all arrays of type FromNativeT converted to type +// ToNativeT in the given literal. +template +std::unique_ptr ConvertType(const Literal& literal) { + // First construct shape of the result. + Shape result_shape(literal.shape()); + ShapeUtil::ForEachMutableSubshape( + &result_shape, [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == + primitive_util::NativeToPrimitiveType()) { + subshape->set_element_type( + primitive_util::NativeToPrimitiveType()); + } + }); + auto result = MakeUnique(result_shape); + + // Then copy over the data from 'literal' converting FromNativeT values to + // ToNativeT values as necessary. + ShapeUtil::ForEachSubshape( + literal.shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + if (subshape.element_type() == + primitive_util::NativeToPrimitiveType()) { + auto src = literal.data(shape_index); + auto dest = result->data(shape_index); + for (int64 i = 0; i < src.size(); ++i) { + dest[i] = static_cast(src[i]); + } + } else { + TF_CHECK_OK(result->CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); + } + } + }); + return result; +} + +} // namespace + /* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( - const Literal& bf16_literal) { - CHECK_EQ(bf16_literal.shape().element_type(), BF16); - Shape converted_shape = bf16_literal.shape(); - converted_shape.set_element_type(F32); - auto converted = Literal::CreateFromShape(converted_shape); - if (!ShapeUtil::HasZeroElements(converted_shape)) { - std::vector index(converted_shape.dimensions_size(), 0); - do { - converted->Set( - index, static_cast(bf16_literal.Get(index))); - } while (IndexUtil::BumpIndices(converted_shape, &index)); - } - return converted; + const Literal& literal) { + return ConvertType(literal); } /* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( - const Literal& f32_literal) { - CHECK_EQ(f32_literal.shape().element_type(), F32); - Shape converted_shape = f32_literal.shape(); - converted_shape.set_element_type(BF16); - auto converted = Literal::CreateFromShape(converted_shape); - if (!ShapeUtil::HasZeroElements(converted_shape)) { - std::vector index(converted_shape.dimensions_size(), 0); - do { - converted->Set( - index, static_cast(f32_literal.Get(index))); - } while (IndexUtil::BumpIndices(converted_shape, &index)); - } - return converted; + const Literal& literal) { + return ConvertType(literal); } namespace { @@ -290,9 +312,10 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, break; case TUPLE: { bool tuple_match = true; - for (int i = 0; i < actual.tuple_literals_size(); ++i) { - auto result = - Equal(expected.tuple_literals(i), actual.tuple_literals(i)); + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + // Create LiteralViews of the expected and actual elements. + auto result = Equal(LiteralView::Create(expected, {i}), + LiteralView::Create(actual, {i})); tuple_match = tuple_match ? !!result : false; } match = tuple_match; @@ -313,23 +336,45 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, return result; } -/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, - const Literal& actual) { +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple( + const Literal& expected, const Literal& actual) { VLOG(1) << "expected: " << expected.ToString(); VLOG(1) << "actual: " << actual.ToString(); - ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape())); - ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape())); + if (!ShapeUtil::IsTuple(expected.shape()) || + !ShapeUtil::IsTuple(actual.shape())) { + return ::testing::AssertionFailure() + << "tuples expected shape = " << expected.shape().ShortDebugString() + << " actual shape = " << actual.shape().ShortDebugString(); + } AssertEqualShapes(expected.shape(), actual.shape()); - for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { - const auto& expected_element = expected.tuple_literals(i); - const auto& actual_element = actual.tuple_literals(i); - if (ShapeUtil::IsTuple(expected_element.shape())) { - ExpectEqualTuple(expected_element, actual_element); - } else { - ExpectEqual(expected_element, actual_element); + + ::testing::AssertionResult err = ::testing::AssertionSuccess(); + for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); + const auto expected_element = LiteralView::Create(expected, {i}); + const auto actual_element = LiteralView::Create(actual, {i}); + + ::testing::AssertionResult res = [&] { + if (ShapeUtil::IsTuple(expected_element.shape())) { + return EqualTuple(expected_element, actual_element); + } else { + return Equal(expected_element, actual_element); + } + }(); + + if (!res && err) { + err = res; } } + + return err; +} + +/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, + const Literal& actual) { + EXPECT_TRUE(EqualTuple(expected, actual)); } namespace { @@ -365,10 +410,7 @@ class NearComparator { abs_expected_miscompare_sum_ = 0.0; max_rel_err_ = 0.0; max_abs_err_ = 0.0; - *miscompares_.mutable_shape() = - ShapeUtil::ChangeElementType(actual.shape(), PRED); - miscompares_.mutable_preds()->resize( - ShapeUtil::ElementsIn(miscompares_.shape()), false); + miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED)); multi_index_.resize(expected.shape().dimensions_size(), 0); switch (expected.shape().element_type()) { @@ -595,33 +637,33 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, if (!ShapeUtil::IsTuple(expected.shape()) || !ShapeUtil::IsTuple(actual.shape())) { return ::testing::AssertionFailure() - << "tuples expected expected shape = " - << expected.shape().ShortDebugString() + << "tuples expected shape = " << expected.shape().ShortDebugString() << " actual shape = " << actual.shape().ShortDebugString(); } AssertEqualShapes(expected.shape(), actual.shape()); - for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { - const auto& expected_element = expected.tuple_literals(i); - const auto& actual_element = actual.tuple_literals(i); - if (ShapeUtil::IsTuple(expected_element.shape())) { - auto ret = NearTuple(expected_element, actual_element, error); - if (!ret) { - return ret; - } - } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) { - auto ret = Near(expected_element, actual_element, error); - if (!ret) { - return ret; - } - } else { - auto ret = Equal(expected_element, actual_element); - if (!ret) { - return ret; + + ::testing::AssertionResult err = ::testing::AssertionSuccess(); + for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); + const auto expected_element = LiteralView::Create(expected, {i}); + const auto actual_element = LiteralView::Create(actual, {i}); + + ::testing::AssertionResult res = [&] { + if (ShapeUtil::IsTuple(expected_element.shape())) { + return NearTuple(expected_element, actual_element, error); + } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) { + return Near(expected_element, actual_element, error); + } else { + return Equal(expected_element, actual_element); } + }(); + + if (err && !res) { + err = res; } } - - return ::testing::AssertionSuccess(); + return err; } /* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected, @@ -630,6 +672,32 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, EXPECT_TRUE(NearTuple(expected, actual, error)); } +/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + bool is_tuple = ShapeUtil::IsTuple(expected.shape()); + if (error.has_value()) { + if (is_tuple) { + VLOG(1) << "Expects near tuple"; + return NearTuple(expected, actual, *error); + } + VLOG(1) << "Expects near"; + return Near(expected, actual, *error); + } + if (is_tuple) { + VLOG(1) << "Expects equal tuple"; + return EqualTuple(expected, actual); + } + VLOG(1) << "Expects equal"; + return Equal(expected, actual); +} + +/*static*/ void LiteralTestUtil::ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + EXPECT_TRUE(NearOrEqual(expected, actual, error)); +} + /* static */ string LiteralTestUtil::MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index) { return tensorflow::strings::StrCat( @@ -645,9 +713,8 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, } CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); - auto new_literal = MakeUnique(); - *new_literal->mutable_shape() = - ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions); + auto new_literal = MakeUnique( + ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. This shape is used // solely for converting linear address to multi-dimensional addresses when @@ -655,9 +722,6 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, Shape shape_with_layout = new_literal->shape(); *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); - // Allocate space in the new literal. - new_literal->Reserve(ShapeUtil::ElementsIn(literal.shape())); - // Copy data into new literal, element-by-element. for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { std::vector from_multi_index = diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 6e4add2690fd958d555eab3cef51cdbbd01819c9..f53553c70170bdcda717e72ffd791016effd0774 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -59,10 +60,14 @@ class LiteralTestUtil { static void AssertEqualShapesAndLayouts(const Shape& expected, const Shape& actual); - // Converts a bfloat16 literal to a float literal. + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); - // Converts a float literal to a bfloat16 literal. + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); // Asserts that the expected and actual literals are (bitwise) equal for all @@ -106,6 +111,10 @@ class LiteralTestUtil { static void ExpectR4EqualArray4D(const Array4D& expected, const Literal& actual); + // Returns whether the two tuples are equal. + static ::testing::AssertionResult EqualTuple( + const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + // Expects that the values of the elements in the expected and actual tuples // are equal. Tuples are matched recursively. static void ExpectEqualTuple(const Literal& expected, const Literal& actual); @@ -173,6 +182,19 @@ class LiteralTestUtil { static void ExpectNearTuple(const Literal& expected, const Literal& actual, const ErrorSpec& error); + // If the error spec is given, returns whether the expected and the actual are + // within the error bound; otherwise, returns whether they are equal. Tuples + // will be compared recursively. + static ::testing::AssertionResult NearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; + + // If the error spec is given, expects the expected and the actual to be near; + // otherwise, expects them to be equal. Tuples will be compared recursively. + static void ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error); + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will // be returned for a 2-dimensional index with dimension 0 index equal to 7, // dimension 1 equal to 8. diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 2acf27ed390b0732ba40fcf505c746bd7d8b651e..e477784557a3b9340cff644a3695485389d8cc22 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -83,13 +83,14 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, &literal_proto)); - Literal literal(literal_proto); + std::unique_ptr literal = + Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal.ToString()); + EXPECT_EQ("2", literal->ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal.ToString()); + EXPECT_EQ("4", literal->ToString()); } else if (result.find("miscompares") != string::npos) { - EXPECT_EQ("true", literal.ToString()); + EXPECT_EQ("true", literal->ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 0cd44a72b5818c1bf66fd4cd1929572038596b47..4d3b513b092e0b447a1452a3809fb7099e54dbb9 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -63,8 +63,6 @@ int main(int argc, char** argv) { triple_string = "x86_64-apple-macosx"; } else if (target_cpu == "arm") { triple_string = "aarch64-none-linux-gnu"; - } else if (target_cpu == "ppc") { - triple_string = "powerpc64le-unknown-linux-gnu"; } else if (target_cpu == "local") { triple_string = xla::llvm_ir::AsString(llvm::sys::getDefaultTargetTriple()); } else { diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index ad71d40197fe48b4343ee5f5f7f71b282a05cbf5..2462ea39f914b1dbb525ea777a48d9ce66035638 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -138,13 +138,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { // Create x as a col-major array. auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); - EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(x_array->on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); - EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(y_array->on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); std::unique_ptr result_colmaj = @@ -179,7 +179,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, *ShapedBufferToLiteral(*result_colmaj), @@ -191,7 +191,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, *ShapedBufferToLiteral(*result_rowmaj), @@ -213,16 +213,17 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - result_literal->tuple_literals(1)); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(2)); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal( + {{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralView::Create(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -241,19 +242,21 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(1)); - const Literal& inner_tuple_literal = result_literal->tuple_literals(0); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - inner_tuple_literal.tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - inner_tuple_literal.tuple_literals(1)); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - inner_tuple_literal.tuple_literals(2)); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralView::Create(*result_literal, {0, 0})); + LiteralTestUtil::ExpectR2Equal( + {{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralView::Create(*result_literal, {0, 1})); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralView::Create(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -278,10 +281,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { DefaultExecutableRunOptions()); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -320,14 +323,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_buffer.get(), y_buffer.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, - result_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, - result_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{56.0f, 46.0f}, {36.0f, 26.0f}}, + LiteralView::Create(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal( + {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -365,10 +369,10 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { ExecuteLocallyOrDie(computation, {arg_buffer.get()}); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, - result_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, - result_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal( + {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -395,18 +399,19 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { std::unique_ptr result_0 = ExecuteLocallyOrDie(computation, {arg_buffer.get()}); std::unique_ptr result_0_literal = ShapedBufferToLiteral(*result_0); - LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, - result_0_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, - result_0_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{-1.0, -2.0}, {-3.0, -4.0}}, + LiteralView::Create(*result_0_literal, {0})); + LiteralTestUtil::ExpectR2Equal( + {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1})); std::unique_ptr result_1 = ExecuteLocallyOrDie(computation, {result_0.get()}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(*result_1); - LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, - result_1_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, - result_1_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0})); + LiteralTestUtil::ExpectR2Equal( + {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -455,7 +460,8 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, result_literal->tuple_literals(i), error_spec_); + {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}), + error_spec_); } } @@ -512,8 +518,8 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) { for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, - result_literal->tuple_literals(i).tuple_literals(j), error_spec_); + i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}), + error_spec_); } } } @@ -554,11 +560,12 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { ExecuteLocallyOrDie(computation, {arg_buffer.get()}); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - const Literal* result_element = result_literal.get(); + ShapeIndex index; for (int i = 0; i < kTupleDepth; ++i) { - result_element = &result_element->tuple_literals(0); + index.push_back(0); } - LiteralTestUtil::ExpectR0Equal(165.0, *result_element); + LiteralTestUtil::ExpectR0Equal( + 165.0, LiteralView::Create(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -575,7 +582,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { EXPECT_FALSE(execute_status.ok()); EXPECT_THAT(execute_status.status().error_message(), - ContainsRegex("invalid number of arguments")); + ContainsRegex("Invalid number of arguments")); } XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { @@ -591,7 +598,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { EXPECT_FALSE(execute_status.ok()); EXPECT_THAT(execute_status.status().error_message(), - ContainsRegex("invalid argument shape")) + ContainsRegex("Invalid argument shape")) << execute_status.status(); } @@ -763,10 +770,10 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { std::unique_ptr result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, - tuple_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, - tuple_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR1Equal( + {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0})); + LiteralTestUtil::ExpectR1Equal( + {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { @@ -906,20 +913,18 @@ void BM_LocalClientOverhead(int num_iters) { builder.Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); - auto shape_size_fn = [client](const Shape& shape) { - return client->backend().transfer_manager()->GetByteSizeRequirement(shape); - }; - auto buffer = ScopedShapedBuffer::Allocate( - shape, &allocator, /*device_ordinal=*/0, shape_size_fn) - .ConsumeValueOrDie(); + auto buffer = + transfer_manager + ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, buffer->mutable_buffer({}))); + executors[device_ordinal], *literal, *buffer)); const int kWarmups = 2; - auto executable_status = client->Compile(computation, {&buffer->shape()}, - ExecutableBuildOptions()); + auto executable_status = client->Compile( + computation, {&buffer->on_host_shape()}, ExecutableBuildOptions()); ASSERT_IS_OK(executable_status); std::unique_ptr executable = executable_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 062a9246e49598d5d03dce8c1f437138923449bf..96b976d25d75d35f46adfd104a03aceb363661eb 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -188,7 +188,7 @@ LocalClientTestBase::ExecuteLocally( const ExecutableRunOptions& run_options) { std::vector argument_layouts(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { - argument_layouts[i] = &arguments[i]->shape(); + argument_layouts[i] = &arguments[i]->on_host_shape(); } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 22d2b917a1d55f4f453e21c2d8fea38e32ff796b..6e6cb7ff1e2ac74dc54f14d8811c9a5d3662bbd2 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -76,8 +76,11 @@ class MultiOutputFusionTest : public HloTestBase { elem_shape2, HloOpcode::kAdd, broadcast, param1)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( elem_shape2, HloOpcode::kSubtract, param1, broadcast)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape2, HloOpcode::kDot, sub, add2)); + HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -96,14 +99,13 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input; - input.PopulateWithValue(2.5f, {size, size}); - auto p1 = TransferToDevice(input); - auto p0 = TransferToDevice(*Literal::CreateR0(-9.0f)); + Literal arg1(ShapeUtil::MakeShape(F32, {size, size})); + arg1.PopulateWithValue(2.5f); - Literal expect; - expect.PopulateWithValue(size * 1.5f * 3.5f, {size, size}); - auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + Literal expect(ShapeUtil::MakeShape(F32, {size, size})); + expect.PopulateWithValue(size * 1.5f * 3.5f); + auto actual = ExecuteAndTransfer( + std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); } @@ -133,8 +135,11 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {size, 1}), add)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1}), HloOpcode::kDot, sub, reshape)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -154,14 +159,13 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input0, input1; - input0.PopulateWithValue(2.5f, {size}); - input1.PopulateWithValue(1, {size}); - auto p0 = TransferToDevice(input0); - auto p1 = TransferToDevice(input1); + Literal input0(ShapeUtil::MakeShape(F32, {size})); + input0.PopulateWithValue(2.5f); + Literal input1(ShapeUtil::MakeShape(F64, {size})); + input1.PopulateWithValue(1.); - Literal expect = *Literal::CreateR1({size * 1.5f * 3.5f}); - auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f})); + auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); } }; diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index b7f62b8aa167b2d9ef1bb2fa83af5aaeda1d6652..bb7e800df84121f2045141bc366c34b94ba694ea 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -334,10 +334,109 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); } +// Test large number of parameters flowing into a while-loop. +// Construct conceptually the following HLO graph: +// +// p0 = parameter(0) +// p1 = parameter(1) +// ... +// pN = parameter(N) +// result = while (false) { +// p0 += (1, 1); +// p1 += (1, 1); +// ... +// pN += (1, 1) +// } +// result = {p0, p1, ..., pN} +// +// TODO(b/70173746): Times out during compilation on GPU and CPU backends as of +// 2017-12-12. +XLA_TEST_F(ParamsTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) { + ComputationBuilder builder(client_, TestName()); + + std::vector> param_data_owner; + constexpr int kParamCount = 1900; + std::vector params; + std::vector parameter_shapes; + for (int i = 0; i < kParamCount; ++i) { + std::unique_ptr literal = Literal::CreateR1({i, i}); + param_data_owner.push_back( + std::move(client_->TransferToServer(*literal)).ValueOrDie()); + ComputationDataHandle param = + builder.Parameter(i, literal->shape(), "param"); + params.push_back(param); + parameter_shapes.push_back(literal->shape()); + } + + // Add bool parameter for the loop condition. Use a parameter HLO instead of a + // constant because DCE may eliminate the while-body otherwise. + std::unique_ptr bool_literal = Literal::CreateR0(false); + param_data_owner.push_back( + std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + ComputationDataHandle bool_param = + builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); + params.push_back(bool_param); + parameter_shapes.push_back(bool_literal->shape()); + + auto init = builder.Tuple(params); + + // Create a computation for the condition: while(bool_param). + Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto condition_parameter = + builder.Parameter(0, while_shape, "condition_parameter"); + builder.GetTupleElement(condition_parameter, kParamCount); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add {1, 1} to the each tuple element. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); + std::vector updates; + for (int i = 0; i < kParamCount; ++i) { + auto add = builder.Add(builder.GetTupleElement(body_parameter, i), + builder.ConstantR1({1, 1})); + updates.push_back(add); + } + // Add bool parameter. + updates.push_back(builder.GetTupleElement(body_parameter, kParamCount)); + + builder.Tuple(updates); + body = builder.Build().ConsumeValueOrDie(); + } + + auto loop = builder.While(condition, body, init); + + std::vector outputs; + for (int i = 0; i < kParamCount; ++i) { + outputs.push_back(builder.GetTupleElement(loop, i)); + } + builder.Tuple(outputs); + + std::vector param_data; + param_data.reserve(param_data_owner.size()); + for (const std::unique_ptr& data : param_data_owner) { + param_data.push_back(data.get()); + } + + std::vector> elements; + std::vector ptrs; + for (int i = 0; i < kParamCount; ++i) { + elements.push_back(Literal::CreateR1({i, i})); + ptrs.push_back(elements.back().get()); + } + ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); +} + #endif -XLA_TEST_F(ParamsTest, - DISABLED_ON_CPU_PARALLEL(TupleOfR1ParametersAddedTogether)) { +XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { ComputationBuilder builder(client_, TestName()); Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3}); @@ -363,10 +462,8 @@ XLA_TEST_F(ParamsTest, // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = Literal::CreateR2({ - {1, 2}, {3, 4}, - }); - *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + std::unique_ptr literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); ComputationBuilder builder(client_, TestName()); builder.Parameter(0, literal->shape(), "input"); @@ -377,10 +474,8 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = Literal::CreateR2({ - {1, 3}, {2, 4}, - }); - *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + std::unique_ptr literal = Literal::CreateR2WithLayout( + {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); ComputationBuilder builder(client_, TestName()); builder.Parameter(0, literal->shape(), "input"); @@ -401,7 +496,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { original.layout().minor_to_major().begin(), original.layout().minor_to_major().end()); std::reverse(original_layout.begin(), original_layout.end()); - *literal->mutable_shape()->mutable_layout() = + *literal->mutable_shape_do_not_use()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); ASSERT_EQ(2, literal->Get({0, 1})); } diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 209f063cc5a34648453d12deae79f261b95dc3b4..6489eee9f34c6c4426d52e166f7b401d5948742f 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -37,6 +37,8 @@ class PrngTest : public ClientLibraryTestBase { protected: template void UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims); + + template void BernoulliTest(float p, tensorflow::gtl::ArraySlice dims); // Computes the χ² statistic of a sample of the discrete uniform distribution @@ -62,37 +64,6 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { }); } -void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { - ComputationBuilder builder(client_, TestName()); - auto shape = ShapeUtil::MakeShape(U32, dims); - builder.RngBernoulli(builder.ConstantR0(p), shape); - - TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); - ExecutionOptions execution_options = execution_options_; - execution_options.set_seed(42); - TF_ASSERT_OK_AND_ASSIGN( - auto actual, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options)); - EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - int32 sum = 0; - actual->EachCell( - [&sum](tensorflow::gtl::ArraySlice, uint32 value) { - EXPECT_TRUE(value == 0 || value == 1); - sum += value; - }); - int32 total = ShapeUtil::ElementsIn(shape); - float p_tilde = sum / static_cast(total); - - // Test within expected range using normal approximation. The test uses a - // fixed seed and has a fixed output per p and backend. Using the normal - // approximation as this test is invoked for different `p` and the different - // backends could use different random number generators and produce different - // values. Choose 95% confidence level, so that z_{1-\alpha/2} = 1.96. - float normal_approximation_term = 1.96 * sqrt(p * (1 - p) / total); - EXPECT_GE(p_tilde, p - normal_approximation_term); - EXPECT_LE(p_tilde, p + normal_approximation_term); -} - // Uniform random number generation tests XLA_TEST_F(PrngTest, ScalarU01) { UniformTest(0, 1, {}); } XLA_TEST_F(PrngTest, ZeroValuesU01) { UniformTest(0, 1, {0}); } @@ -181,10 +152,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) { computation, /*arguments=*/{param0_data.get()}, &execution_options)); - EXPECT_EQ(actual->f32s_size(), param0_literal->f32s_size()); - for (int i = 0; i < param0_literal->f32s_size(); ++i) { - EXPECT_GE(actual->f32s(i), param0_literal->f32s(i)); - EXPECT_LT(actual->f32s(i), param0_literal->f32s(i) + 1.0f); + EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()), + ShapeUtil::ElementsIn(param0_literal->shape())); + for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) { + EXPECT_GE(actual->data()[i], param0_literal->data()[i]); + EXPECT_LT(actual->data()[i], + param0_literal->data()[i] + 1.0f); } } @@ -250,10 +223,6 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { LiteralTestUtil::ExpectNotEqual(*result5, *result6); } -// Bernoulli random number generation tests -XLA_TEST_F(PrngTest, HundredValuesB10p5) { BernoulliTest(0.5, {100}); } -XLA_TEST_F(PrngTest, HundredValuesB10p1) { BernoulliTest(0.1, {100}); } - XLA_TEST_F(PrngTest, TenValuesN01) { ComputationBuilder builder(client_, TestName()); builder.RngNormal(builder.ConstantR0(0), builder.ConstantR0(1), diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 7bc3185c367f076c9a7d211c9799557e1a91d92f..a766fa2db0e193c52171490981855843ab3ee158 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -143,6 +143,55 @@ class ReduceTest : public ClientLibraryTestBase { ComputeAndCompareR0(&builder, expected, {input_global_data.get()}); } + // Reduce predicate tensor with dimension rows * cols to dimension cols, to + // test the implementation of atomic operations on misaligned small data + // types. + template + void RunR2ToR1PredTest(bool and_reduce, int64 rows, int64 minor = 1, + int64 major = 0) { + ComputationBuilder builder(client_, TestName()); + const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols}); + auto input = builder.Parameter(0, input_shape, "input"); + auto input_pred = builder.Eq(input, builder.ConstantR0(1)); + + ComputationDataHandle init_value; + Computation reduce_op; + if (and_reduce) { + init_value = builder.ConstantR0(true); + reduce_op = CreateScalarAndComputation(&builder); + } else { + init_value = builder.ConstantR0(false); + reduce_op = CreateScalarOrComputation(&builder); + } + + builder.Reduce(input_pred, init_value, reduce_op, + /*dimensions_to_reduce=*/{0}); + + Array2D input_data(rows, cols); + input_data.FillRandom(0, 1); + std::unique_ptr input_literal = + Literal::CreateR2FromArray2D(input_data); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + std::array expected; + for (int64 colno = 0; colno < cols; ++colno) { + bool column_sum = and_reduce ? true : false; + for (int64 rowno = 0; rowno < rows; ++rowno) { + if (and_reduce) { + column_sum = column_sum && input_data(rowno, colno); + } else { + column_sum = column_sum || input_data(rowno, colno); + } + } + expected[colno] = column_sum; + } + + ComputeAndCompareR1(&builder, expected, {input_global_data.get()}); + } + // Runs an R2 => R0 reduction test with the given number of (rows, cols). void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) { ComputationBuilder builder(client_, TestName()); @@ -352,15 +401,13 @@ XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) { XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); } XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { @@ -369,15 +416,13 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/false, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { @@ -812,5 +857,12 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); } +XLA_TEST_F(ReduceTest, ReduceAndPredR2_128x64_To_R1) { + RunR2ToR1PredTest(/*and_reduce=true*/ true, /*rows=128*/ 128); +} +XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) { + RunR2ToR1PredTest(/*and_reduce=false*/ false, /*rows=64*/ 64); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 0601a1466bd87ab721443e0da725006e2d73e392..01f23efcd52e3b227309df3b7d965f3b4c3a0cdf 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -41,16 +41,40 @@ limitations under the License. namespace xla { namespace { -class ReduceWindowTest : public ClientLibraryTestBase { +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. +static std::array use_bfloat16_params{false}; +#endif + +class ReduceWindowTestBase : public ClientLibraryTestBase { public: - ReduceWindowTest() : builder_(client_, TestName()) {} + ErrorSpec DefaultErrorSpec() const { + if (use_bfloat16()) { + return ErrorSpec(1e-1, 5e-2); + } else { + return ErrorSpec(1e-3, 1e-3); + } + } +}; + +class ReduceWindowTest : public ::testing::WithParamInterface, + public ReduceWindowTestBase { + public: + ReduceWindowTest() : builder_(client_, TestName()) { + set_use_bfloat16(GetParam()); + } void ReduceWindowAdd(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, builder_.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder_), + auto init = + CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); } @@ -58,30 +82,32 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow( - input, builder_.ConstantLiteral(Literal::MinValue(F32)), - CreateScalarMax(), window_dimensions, window_strides, padding); + auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); + builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions, + window_strides, padding); } void ReduceWindowMin(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, - builder_.ConstantLiteral(Literal::MaxValue(F32)), - CreateScalarMinComputation(F32, &builder_), + auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarMinComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); } ComputationBuilder builder_; }; -TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { - const auto input = builder_.ConstantR1({1, 1, 1, 1}); - const auto init_value = builder_.ConstantR0(0); +TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({1, 1, 1, 1}), &builder_); + const auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); builder_.ReduceWindow(input, init_value, - CreateScalarAddComputation(F32, &builder_), + CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{1, 2}, /*window_strides=*/{1}, Padding::kValid); ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) @@ -91,88 +117,106 @@ TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { } // Regression test for b/68964348. -TEST_F(ReduceWindowTest, R0ReduceWindow) { - auto input = builder_.ConstantR0(42); - auto init = builder_.ConstantR0(1.0); - builder_.ReduceWindow(input, init, CreateScalarAddComputation(F32, &builder_), +TEST_P(ReduceWindowTest, R0ReduceWindow) { + const auto input = + CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); + const auto init = + CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareR0(&builder_, 43, {}, ErrorSpec(0.00001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, + ErrorSpec(0.00001)); } -TEST_F(ReduceWindowTest, Min3In5Stride2) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); +TEST_P(ReduceWindowTest, Min3In5Stride2) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({100, 1}), {}, + ErrorSpec(0.00001)); } -XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) { - Array4D input_array(1, 0, 2, 1); +TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, + Padding::kSame); + ComputeAndCompareLiteral(&builder_, + *Literal::CreateR1({1000, 100, 10, 1, 1}), {}, + ErrorSpec(0.00001)); +} - const auto input = builder_.ConstantR4FromArray4D(input_array); +XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { + Array4D input_array(1, 0, 2, 1); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, NonSquareSmall) { +TEST_P(ReduceWindowTest, NonSquareSmall) { Array4D input_array(1, 2, 2, 1); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, MiddleDimsSmall) { +TEST_P(ReduceWindowTest, MiddleDimsSmall) { Array4D input_array(1, 3, 3, 1); - input_array.FillRandom(2.f); - - const auto input = builder_.ConstantR4FromArray4D(input_array); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, Along2ndMinorDim) { +TEST_P(ReduceWindowTest, Along2ndMinorDim) { Array4D input_array(3, 6, 7, 32); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); // The parameters of this reduction mimic feature norm (e.g. LRN). int lrn_diameter = 7; // diameter = 2*radius + 1 --> must be odd - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2Dims) { +TEST_P(ReduceWindowTest, AmongMajor2Dims) { Array4D input_array(4, 4, 6, 8); input_array.FillWithMinorDimNum(); + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); int win_len = 3; int win_stride = 1; Padding padding = Padding::kSame; - const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); // Reduce only along the x and y dimensions, according to the win_len. ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -180,18 +224,20 @@ TEST_F(ReduceWindowTest, AmongMajor2Dims) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { Array4D input_array(9, 12, 4, 89); - input_array.FillRandom(2.0f); + input_array.FillRandom(2.f, 2.f); int win_len = 3; int win_stride = 2; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; // Reduce only along the x and y dimensions, according to the win_len. @@ -202,116 +248,34 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); -} - -// TODO(b/32173947): Test support for arbitrary-sized padding. -TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { - Array4D input_array(9, 12, 4, 89); // simulate Dim0IsMinor layout - input_array.FillRandom(2.0f); - - int64 rank = 4; - int win_len = 3; - int win_stride = 2; - - const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); - - Padding padding = Padding::kSame; - // Reduce only along the x and y dimensions, according to the win_len. - // Create padding vector with large padding values in the reduction dims. - std::vector> low_high_padding; - low_high_padding.resize(rank, {4, 4}); - - builder_.ReduceWindowWithGeneralPadding( - input_data_handle, builder_.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder_), {win_len, win_len, 1, 1}, - {win_stride, win_stride, 1, 1}, low_high_padding); - - auto result = ReferenceUtil::ReduceWindow4DAdd( - input_array, 0.0f, {win_len, win_len, 1, 1}, - {win_stride, win_stride, 1, 1}, padding); - - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); -} - -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { - Array3D input_array(2, 1, 2); - input_array(0, 0, 0) = 1000; - input_array(0, 0, 1) = 100; - input_array(1, 0, 0) = 10; - input_array(1, 0, 1) = 1; - auto input = builder_.ConstantR3FromArray3D(input_array); - - ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kValid); - - Array3D expected(2, 1, 1); - expected(0, 0, 0) = 1100; - expected(1, 0, 0) = 11; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { - Array3D input_array(2, 1, 3); - input_array(0, 0, 0) = 100; - input_array(0, 0, 1) = 10; - input_array(0, 0, 2) = 1; - input_array(1, 0, 0) = 500; - input_array(1, 0, 1) = 50; - input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); - - ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 2}, Padding::kValid); - - Array3D expected(2, 1, 1); - expected(0, 0, 0) = 110; - expected(1, 0, 0) = 550; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { - Array3D input_array(2, 1, 3); - input_array(0, 0, 0) = 100; - input_array(0, 0, 1) = 10; - input_array(0, 0, 2) = 1; - input_array(1, 0, 0) = 500; - input_array(1, 0, 1) = 50; - input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); - - ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kSame); - - Array3D expected(2, 1, 3); - expected(0, 0, 0) = 110; - expected(0, 0, 1) = 11; - expected(0, 0, 2) = 1; - expected(1, 0, 0) = 550; - expected(1, 0, 1) = 55; - expected(1, 0, 2) = 5; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. -XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { +XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Array4D input_array(1, 2, 2, 1); input_array(0, 0, 0, 0) = 1; input_array(0, 0, 1, 0) = 2; input_array(0, 1, 0, 0) = 3; input_array(0, 1, 1, 0) = 4; + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kValid; - - const Shape scalar = ShapeUtil::MakeShape(F32, {}); + const Shape scalar = ShapeUtil::MakeShape(FloatType(), {}); auto b = builder_.CreateSubBuilder("unusual"); auto lhs = b->Parameter(0, scalar, "lhs"); auto rhs = b->Parameter(1, scalar, "rhs"); - b->Min(b->Add(lhs, rhs), b->ConstantR0(8.0f)); + b->Min(b->Add(lhs, rhs), + CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); Computation reduce_fn = b->BuildAndNoteError(); - builder_.ReduceWindow(input, builder_.ConstantR0(3.0f), reduce_fn, - /*window_dimensions=*/{1, 1, 2, 1}, - /*window_strides=*/{1, 1, 1, 1}, padding); + builder_.ReduceWindow( + input, + CreateConstantFromLiteral(*Literal::CreateR0(3.0f), &builder_), + reduce_fn, + /*window_dimensions=*/{1, 1, 2, 1}, + /*window_strides=*/{1, 1, 1, 1}, padding); const auto reduce_func = [](float arg1, float arg2) { return std::min(arg1 + arg2, 8.0f); @@ -322,17 +286,19 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R4UnitWindow) { +TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); - input_array.Fill(1.0f); + input_array.FillRandom(2.f, 2.f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -340,15 +306,11 @@ TEST_F(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); @@ -358,56 +320,15 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(arg_literal->Populate(generator)); - auto input = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); - trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - // Non-monotonic output layout with minor dims trivial. + const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); + std::vector output_layout = {1, 5, 3, 2, 0, 4}; std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - b.AddInstruction(HloInstruction::CreateReduceWindow( - result_shape, input, init_value, window, add_func)); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); auto out_generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { @@ -415,82 +336,37 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(expected->Populate(out_generator)); - module->AddEntryComputation(b.Build()); - auto actual = ExecuteAndTransfer(std::move(module), {}); - - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6Add) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); + auto shape = ShapeUtil::MakeShape(F32, input_dims); + std::unique_ptr arg_literal = - Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); - auto input = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); - trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - Shape shape = ShapeUtil::MakeShape(F32, {8, 8, 6, 6, 8, 8}); - b.AddInstruction(HloInstruction::CreateReduceWindow(shape, input, init_value, - window, add_func)); + Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); + + const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr expected = - Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 9.0f); - - module->AddEntryComputation(b.Build()); - auto actual = ExecuteAndTransfer(std::move(module), {}); + Literal::CreateFullWithDescendingLayout(output_dims, 9.0f); - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -500,20 +376,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -523,20 +398,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -546,13 +420,11 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Array4D input_array(6, 4, 10, 130); input_array.FillRandom(2.0f); @@ -561,7 +433,7 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Padding padding = Padding::kSame; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); // Reduce only along the x and y dimensions, according to the win_len. ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -569,36 +441,59 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add24In1152_NoOverlap) { +XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); - const auto input = builder_.ConstantR1(input_vector); + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {32, 32, 32, 32, 32, 32, 32, 32, 32}, - {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral( + &builder_, + *Literal::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add128In128Stride128) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); +XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { + std::vector input_vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {1088}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + DefaultErrorSpec()); +} + +XLA_TEST_P(ReduceWindowTest, Add128In128) { + std::vector input_vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); + ReduceWindowAdd(input, {128}, {1}, Padding::kValid); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + DefaultErrorSpec()); } // Regression test for a bug that appeared in Inception (b/34784899). -TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { Array2D input_array(14, 14, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {14, 14}); + const auto input = CreateConstantFromArray(input_array, &builder_); int win_len = 3; int stride = 1; @@ -608,13 +503,14 @@ TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {6, 4}); + ComputationDataHandle input = builder_.Broadcast( + CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); @@ -622,9 +518,13 @@ TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } +INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, + ::testing::ValuesIn(use_bfloat16_params)); + enum Reducer { kAdd, kMax }; struct R4ReduceWindowTestData { @@ -638,30 +538,36 @@ struct R4ReduceWindowTestData { }; string R4ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(data.param.pad_low, "x"), // - "__pad_high_", tensorflow::str_util::Join(data.param.pad_high, "x"), // - (data.param.reducer == kAdd) ? "add" : "max"); - CHECK(data.param.reducer == kAdd || data.param.reducer == kMax); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // + (param.reducer == kAdd) ? "add" : "max"); + CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R4ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface { +class R4ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: + R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + void DoIt() { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); const float kInitValue = 0.0f; @@ -670,23 +576,24 @@ class R4ReduceWindowTest input.FillIota(1); std::unique_ptr input_literal = Literal::CreateR4FromArray4D(input); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); std::vector> padding(4); for (int i = 0; i < 4; ++i) { padding[i] = {param.pad_low[i], param.pad_high[i]}; } - auto parameter = b.Parameter(0, input_literal->shape(), "p0"); - auto pad_value = b.ConstantR0(kInitValue); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto computation = param.reducer == kAdd - ? CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); b.ReduceWindowWithGeneralPadding( /*operand=*/parameter, - /*init_value=*/pad_value, + /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, @@ -704,8 +611,8 @@ class R4ReduceWindowTest /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareR4(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); } }; @@ -721,6 +628,14 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*pad_high=*/{0, 0, 0, 0}, /*reducer=*/kAdd}, + // Arbitrary padding (not kSame or kValid). + R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{2, 2, 1, 1}, + /*pad_low=*/{4, 4, 0, 0}, + /*pad_high=*/{4, 4, 0, 0}, + /*reducer=*/kAdd}, + // Zero base bound edge case. R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1}, /*window_bounds=*/{1, 1, 1, 1}, @@ -834,13 +749,15 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*reducer=*/kAdd}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowTestInstantiation, R4ReduceWindowTest, - ::testing::ValuesIn(kR4ReduceWindowTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowTestInstantiation, R4ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; -XLA_TEST_P(R4ReduceWindowLargeTest, DoIt) { DoIt(); } +XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); } // Test cases that are large/slow/failed. const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { @@ -859,10 +776,103 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { /*reducer=*/kAdd}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowLargeTestInstantiation, - R4ReduceWindowLargeTest, - ::testing::ValuesIn(kR4ReduceWindowLargeTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); + +struct R3ReduceWindowTestData { + int64 base_bounds[3]; + int64 window_bounds[3]; + int64 strides[3]; + int64 layout[3]; + Padding padding; + Reducer reducer; +} kR3TestCases[] = { + {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2}, + /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2}, + /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, + /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1}, + /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, + /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, + /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, +}; + +string R3ReduceWindowTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); + string str = tensorflow::strings::StrCat( + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), + "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), + "__strides_", tensorflow::str_util::Join(param.strides, "x"), + "__padding_", param.padding == Padding::kSame ? "same" : "valid", + "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2], + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } + return str; +} + +class R3ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; + +TEST_P(R3ReduceWindowTest, Add) { + ComputationBuilder b(client_, TestName()); + const auto& param = ::testing::get<0>(GetParam()); + CHECK(param.reducer == kAdd); + + const float kInitValue = 0.0f; + Array3D input(param.base_bounds[0], param.base_bounds[1], + param.base_bounds[2], 1.0f); + std::unique_ptr input_literal = + Literal::CreateR3FromArray3DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); + + auto expected = ReferenceUtil::ReduceWindow3DAdd( + /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/param.padding); + + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); +} + +INSTANTIATE_TEST_CASE_P( + R3ReduceWindowTestInstantiation, R3ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR3TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R3ReduceWindowTestDataToString); struct R2ReduceWindowTestData { int64 base_bounds[2]; @@ -910,26 +920,33 @@ struct R2ReduceWindowTestData { }; string R2ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__layout_", data.param.layout[0], "_", data.param.layout[1], // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__padding_", param.padding == Padding::kSame ? "same" : "valid", // + "__layout_", param.layout[0], "_", param.layout[1], // + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R2ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R2ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; TEST_P(R2ReduceWindowTest, Add) { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; @@ -937,12 +954,15 @@ TEST_P(R2ReduceWindowTest, Add) { std::unique_ptr input_literal = Literal::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/CreateScalarAddComputation(F32, &b), + + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); @@ -950,90 +970,145 @@ TEST_P(R2ReduceWindowTest, Add) { /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/param.padding); - ComputeAndCompareR2(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); } -INSTANTIATE_TEST_CASE_P(R2ReduceWindowTestInstantiation, R2ReduceWindowTest, - ::testing::ValuesIn(kR2TestCases), - R2ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R2ReduceWindowTestInstantiation, R2ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR2TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R2ReduceWindowTestDataToString); struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1]; int64 strides[1]; - Padding padding; + int64 pad_low[1]; + int64 pad_high[1]; Reducer reducer; } kR1TestCases[] = { {/*base_bounds=*/{1}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{3}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{4}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{30}, + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{30}, /*strides=*/{27}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 17}, /*window_bounds=*/{7}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 17}, + /*window_bounds=*/{7}, /*strides=*/{64}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{32}, + /*pad_low=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{32}, /*strides=*/{56}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{3}, /*strides=*/{2}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{0}, + /*pad_high=*/{5}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{5}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kAdd}, }; string R1ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), + "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), + "__strides_", tensorflow::str_util::Join(param.strides, "x"), + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R1ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R1ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; TEST_P(R1ReduceWindowTest, DoIt) { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd || param.reducer == kMax); const float kInitValue = 0.0f; @@ -1041,18 +1116,24 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + + std::vector> padding(1); + padding[0] = {param.pad_low[0], param.pad_high[0]}; auto computation = param.reducer == kAdd - ? CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/computation, - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1062,14 +1143,73 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + /*stride=*/param.strides, + /*padding=*/padding); - ComputeAndCompareR1(&b, tensorflow::gtl::ArraySlice(*expected), - {input_arg.get()}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateR1(*expected), + {input_arg.get()}, DefaultErrorSpec()); +} + +INSTANTIATE_TEST_CASE_P( + R1ReduceWindowTestInstantiation, R1ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR1TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R1ReduceWindowTestDataToString); + +// Test class for text-based test cases. Note that this compares with the +// results on the interpreter backend. +class ReduceWindowTextTest : public HloTestBase {}; + +TEST_F(ReduceWindowTextTest, R2General256x384) { + const string& hlo_string = R"( +HloModule R2Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { + operand = f32[256,384]{1,0} parameter(0) + constant = f32[] constant(1) + ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { + const string& hlo_string = R"( +HloModule R2Window +mul { +lhs = f32[] parameter(0) +rhs = f32[] parameter(1) +ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { +operand = f32[256,384]{0,1} parameter(0) +constant = f32[] constant(1) +ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(ReduceWindowTextTest, R2General2x5) { + const string& hlo_string = R"( +HloModule R2Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { + operand = f32[2,5]{1,0} parameter(0) + constant = f32[] constant(1) + ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -INSTANTIATE_TEST_CASE_P(R1ReduceWindowTestInstantiation, R1ReduceWindowTest, - ::testing::ValuesIn(kR1TestCases), - R1ReduceWindowTestDataToString); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index d235b9a1580ecbd6b82a69fca53d259912ff375e..f7b04debd4f5c40a904e32c832b6fc384a03c33b 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -41,326 +41,467 @@ limitations under the License. namespace xla { namespace { -class ReshapeTest : public ClientLibraryTestBase { +// Use a bool parameter to indicate whether to use bfloat16. +class ReshapeTest : public ::testing::WithParamInterface, + public ClientLibraryTestBase { public: + ReshapeTest() { set_use_bfloat16(GetParam()); } + ErrorSpec zero_error_spec_{0.0}; }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. -XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { +XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1OnlyDim) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. -XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { +XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - auto reshape = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); - ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); + auto expected_literal = Literal::CreateR0(1.0f); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { +XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR0(1.0f); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - a = builder.Neg(a); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + auto a = builder.Neg(parameter); auto reshape = builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); - ComputeAndCompareR1(&builder, {-1.0f}, {param0_data.get()}, - zero_error_spec_); + auto expected_literal = Literal::CreateR1({-1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial0x3) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(0, 3); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // TODO(b/29185393): Make this work with the GPU backend. The GPU backend // does not handle zero-sized shapes correctly. Failed last on 2017-05-15 // with an incorrect result rank. -XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR2FromArray2D(Array2D(0, 3)); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0, 3}), "param0"); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {param0_data.get()}, - zero_error_spec_); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial3x0) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(3, 0); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional row vector to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial1x3) { +XLA_TEST_P(ReshapeTest, Trivial1x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional column vector to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial3x1) { +XLA_TEST_P(ReshapeTest, Trivial3x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f}, {2.0f}, {3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Splits an empty vector into an empty matrix. -XLA_TEST_F(ReshapeTest, R1ToR2_0_To_2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateR1({}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Splits a vector into a matrix. -XLA_TEST_F(ReshapeTest, R1ToR2_6_To_2x3) { +XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); - Array2D expected_2x3({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected_2x3, {}, zero_error_spec_); + auto input_literal = + Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}); + auto expected_literal = + Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 2x0 array to a 0x2 array. -XLA_TEST_F(ReshapeTest, Reshape0x2To2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 0}); - - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional row vector to a column vector. -XLA_TEST_F(ReshapeTest, ReshapeRowToCol) { +XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { ComputationBuilder builder(client_, TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); - auto a = builder.ConstantR2FromArray2D(*simple); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{3, 1}); + auto input_literal = Literal::CreateFromArray(*simple); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); - ComputeAndCompareR2(&builder, *expected, {}, zero_error_spec_); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array. -XLA_TEST_F(ReshapeTest, TransposeAsReshape) { +XLA_TEST_P(ReshapeTest, TransposeAsReshape) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 4}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 4}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 0x4 array with ComputationBuilder::Trans. -XLA_TEST_F(ReshapeTest, Transpose0x4) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 4)); - auto result = builder.Transpose(a, {1, 0}); - - ComputeAndCompareR2(&builder, Array2D(4, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array with ComputationBuilder::Trans. -XLA_TEST_F(ReshapeTest, Transpose4x3) { +XLA_TEST_P(ReshapeTest, Transpose4x3) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Transpose(a, {1, 0}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(6, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 3, 0, 0}); - - ComputeAndCompareR4(&builder, Array4D(2, 3, 0, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 3, 0, 0}); + auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ReshapeR4ToR2ZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR4FromArray4D(Array4D(2, 3, 4, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{24, 0}); - - ComputeAndCompareR2(&builder, Array2D(24, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{24, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 6}); - - auto expected2x6 = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); - ComputeAndCompareR2(&builder, *expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 6}); + + auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -// Reshapes a 2-dimensional array with dimensions that are not just a -// rearrangement of the originals (split), and reorder the input (shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 6)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 0}); - - ComputeAndCompareR2(&builder, Array2D(3, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), and reorder the input (shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{2, 6}); - - Array2D expected2x6({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, - {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); - ComputeAndCompareR2(&builder, expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{2, 6}); + Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, + {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); + auto expected_literal = Literal::CreateFromArray(expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // The following tests use the same input 3D array; they test the examples we // show for the Reshape operation in the operation_semantics document. // TODO(b/34503277): find a way to show this code in the documentation without // duplication on the TF documentation server. -Array3D v_array_for_doc_R3_tests({{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}); - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_012) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_120) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, - 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 20, 30}, - {40, 11, 21}, - {31, 41, 12}, - {22, 32, 42}, - {15, 25, 35}, - {45, 16, 26}, - {36, 46, 17}, - {27, 37, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{2, 6, 2}); - Array3D expected( +static Array3D ArrayForDocR3Tests() { + return Array3D({{{10, 11, 12}, {15, 16, 17}}, + {{20, 21, 22}, {25, 26, 27}}, + {{30, 31, 32}, {35, 36, 37}}, + {{40, 41, 42}, {45, 46, 47}}}); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, + 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 11, 12}, + {15, 16, 17}, + {20, 21, 22}, + {25, 26, 27}, + {30, 31, 32}, + {35, 36, 37}, + {40, 41, 42}, + {45, 46, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, + 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 20, 30}, + {40, 11, 21}, + {31, 41, 12}, + {22, 32, 42}, + {15, 25, 35}, + {45, 16, 26}, + {36, 46, 17}, + {27, 37, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{2, 6, 2}); + auto expected_literal = Literal::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareR3(&builder, expected, {}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses the low dimensions of a 4D tensor to get a 2D matrix, without @@ -378,23 +519,26 @@ XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { // Then we collapse Z be collapsed so we just end up with planes: // // 1 2 3 4 5 6 1 2 3 4 5 6 -XLA_TEST_F(ReshapeTest, FullyConnectedCollapse) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { ComputationBuilder builder(client_, TestName()); Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); - auto a = builder.ConstantR4FromArray4D(t2x2x2x3); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{1, 2, 3}); - - Array2D expected2x12( + auto input_literal = Literal::CreateFromArray(t2x2x2x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); + auto expected_literal = Literal::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected2x12, {}, zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // As above, but uses reshape directly. -XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { ComputationBuilder builder(client_, TestName()); Array4D t(2, 1, 2, 2); t(0, 0, 0, 0) = 0; @@ -405,52 +549,68 @@ XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 0, 1) = 5; t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; - auto a = builder.ConstantR4FromArray4D(t); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{2, 4}); - - Array2D expected({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareR2(&builder, expected, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(t); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{2, 4}); + + auto expected_literal = + Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshape various ranks to a scalar. -XLA_TEST_F(ReshapeTest, ToScalar) { +XLA_TEST_P(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { ComputationBuilder b(client_, TestName()); - auto input = Literal::CreateR1({83.0f}); std::vector ones(rank, 1); // this is {1, ..., 1}. std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); - *input->mutable_shape() = ShapeUtil::MakeShape(F32, ones); - b.Reshape(b.ConstantLiteral(*input), dimensions, {}); + Literal input_literal(ShapeUtil::MakeShape(F32, ones)); + std::vector zeros(rank, 0); // this is {0, ..., 0}. + input_literal.Set(zeros, 83.0f); - ComputeAndCompareR0(&b, 83.0f, {}, zero_error_spec_); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", + &b, ¶meter); + b.Reshape(parameter, dimensions, {}); + + auto expected_literal = Literal::CreateR0(83.0f); + ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + zero_error_spec_); } } -XLA_TEST_F(ReshapeTest, BadDimensions) { +XLA_TEST_P(ReshapeTest, BadDimensions) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1}), {}, {}); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {}, {}); EXPECT_THAT( ExecuteToString(&b, {}), ::testing::HasSubstr("not a permutation of the operand dimensions")); } -XLA_TEST_F(ReshapeTest, BadNewSizes) { +XLA_TEST_P(ReshapeTest, BadNewSizes) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1, 2}), {1}, {}); + auto input_literal = Literal::CreateR1({1.0f, 2.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); } -XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { - const Shape parameter_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); +XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, parameter_shape, "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); - // clang-format off - auto literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ + auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { { {0, 1}, @@ -474,8 +634,12 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }, LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on - std::unique_ptr input = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); + Array2D expected_array({ {0, 1, 2, 3, 100, 101, 102, 103}, {222, 333, 444, 555, 666, 777, 888, 999}, @@ -484,72 +648,75 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, + {1, 0}); std::unique_ptr actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); + if (use_bfloat16()) { + expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + } LiteralTestUtil::ExpectEqual(*expected, *actual); } -XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 1, 2, 3}}, {{4, 5, 6, 7}}}, {{{100, 101, 102, 103}}, {{104, 105, 106, 107}}}, {{{200, 201, 202, 203}}, {{204, 205, 206, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. -XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 100, 200, 1}}, {{101, 201, 2, 102}}}, {{{202, 3, 103, 203}}, {{4, 104, 204, 5}}}, {{{105, 205, 6, 106}}, {{206, 7, 107, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); @@ -559,12 +726,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); @@ -572,7 +737,8 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); @@ -582,12 +748,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); @@ -596,7 +760,8 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { } // Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}. -XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { +XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); @@ -606,12 +771,11 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, + /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { @@ -619,10 +783,12 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { *cell; }); auto expected = Literal::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}); + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, NoopReshape) { +XLA_TEST_P(ReshapeTest, NoopReshape) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); @@ -632,18 +798,17 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Reshape(input, /*dimensions=*/{3, 0, 1, 2}, + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, + {2, 3, 0, 1}); std::unique_ptr output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, @@ -652,35 +817,43 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. - EXPECT_EQ(tensorflow::gtl::ArraySlice(input_literal->f32s()), - tensorflow::gtl::ArraySlice(output_literal->f32s())); + if (use_bfloat16()) { + auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + EXPECT_EQ(expected->data(), output_literal->data()); + } else { + EXPECT_EQ(input_literal->data(), output_literal->data()); + } } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { + ComputationBuilder builder(client_, TestName()); + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{0, 1, 2, 3}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {}); + ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{1, 3, 2, 0}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = Literal::CreateR4( + auto expected_2x4x3x1 = Literal::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -691,10 +864,10 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {}); + ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {2, 2, 2, 2}; @@ -706,12 +879,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -723,7 +896,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {1, 1, 250, 300}; @@ -735,12 +908,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -752,7 +925,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {5, 5, 1, 10}; @@ -764,12 +937,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -781,7 +954,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::mt19937 rng; std::uniform_real_distribution distribution; // This happens in NN-Builder MNIST. @@ -794,12 +967,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -811,7 +984,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {3, 3, 1, 3}; @@ -823,12 +996,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) @@ -840,5 +1013,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { zero_error_spec_, &expected->shape()); } +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, ::testing::Bool()); +#else +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, + ::testing::ValuesIn(std::vector{false})); +#endif + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 1f6cfc85ccd25bb22db51411f7376489c14c3603..8fc841f14087cdea02fe44cdaea521ff92122aec 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -28,56 +28,89 @@ limitations under the License. namespace xla { namespace { -class ReverseTest : public ClientLibraryTestBase {}; - -// Tests the reverse operation on a scalar. -XLA_TEST_F(ReverseTest, ReverseScalar) { - ComputationBuilder b(client_, TestName()); - float input = 3.5f; - b.Rev(b.ConstantR0(input), {}); - ComputeAndCompareR0(&b, input, {}); -} - -// Tests the reverse operation on a 0x0 float array on both dimensions. -XLA_TEST_F(ReverseTest, Reverse0x0FloatArray) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR2FromArray2D(Array2D(0, 0)), {0, 1}); - ComputeAndCompareR2(&b, Array2D(0, 0), {}); -} - -// Tests the reverse operation on a 0x1 float array on both dimensions. -XLA_TEST_F(ReverseTest, Reverse0x1FloatArray) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR2FromArray2D(Array2D(0, 1)), {0, 1}); - ComputeAndCompareR2(&b, Array2D(0, 1), {}); +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. +static std::array use_bfloat16_params{false}; +#endif + +struct ReverseSpec { + tensorflow::gtl::ArraySlice input_dims; + tensorflow::gtl::ArraySlice reversal; + bool use_bfloat16; + + string ToTestCaseName() const { + return tensorflow::strings::Printf( + "reverse_%s_in_dims_%s_%s", + tensorflow::str_util::Join(input_dims, "x").c_str(), + tensorflow::str_util::Join(reversal, "x").c_str(), + use_bfloat16 ? "bf16" : "f32"); + } +}; + +static std::vector GetTestCases() { + // clang-format off + return ExpandUseBfloat16( + use_bfloat16_params, + {{{}, {}}, + {{0, 0}, {0, 1}}, + {{0, 1}, {0, 1}}, + {{1, 0}, {0, 1}}, + {{1, 1}, {0, 1}}, + {{2, 0, 4, 3}, {0, 2}}, + {{2, 0, 4, 3}, {1, 3}}, + {{1, 2, 3, 4}, {0, 3}}, + {{4, 3, 2, 1}, {0, 1}}, + }); + // clang-format on } -// Tests the reverse operation on a 1x0 float array on both dimensions. -XLA_TEST_F(ReverseTest, Reverse1x0FloatArray) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR2FromArray2D(Array2D(1, 0)), {0, 1}); - ComputeAndCompareR2(&b, Array2D(1, 0), {}); +void PrintTo(const ReverseSpec& spec, std::ostream* os) { + *os << spec.ToTestCaseName(); } -// Tests the reverse operation on a 1x1 float array on both dimensions. -XLA_TEST_F(ReverseTest, Reverse1x1FloatArray) { - ComputationBuilder b(client_, TestName()); - Array2D input({{3.5f}}); - b.Rev(b.ConstantR2FromArray2D(input), {0, 1}); - ComputeAndCompareR2(&b, input, {}); +class FloatReverseTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + public: + FloatReverseTest() { set_use_bfloat16(GetParam().use_bfloat16); } +}; + +TEST_P(FloatReverseTest, Reverses) { + const ReverseSpec& spec = GetParam(); + std::vector input_vector( + ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); + std::iota(input_vector.begin(), input_vector.end(), 0.0); + auto r1_literal = Literal::CreateR1(input_vector); + auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = AddParam(*input_literal, &builder); + builder.Rev(a, spec.reversal); + + std::unique_ptr expected = input_literal->CloneToUnique(); + std::vector output_indices(spec.input_dims.size()); + expected->EachCell( + [&](tensorflow::gtl::ArraySlice indices, float) { + for (int64 i = 0; i < indices.size(); ++i) { + output_indices[i] = indices[i]; + } + float value = input_literal->Get(indices); + for (int64 dim : spec.reversal) { + output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; + } + expected->Set(output_indices, value); + }); + ComputeAndCompareLiteral(&builder, *expected, {}); } -XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim02) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR4FromArray4D(Array4D(2, 0, 4, 3)), {0, 2}); - ComputeAndCompareR4(&b, Array4D(2, 0, 4, 3), {}); -} +INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, + ::testing::ValuesIn(GetTestCases()), + ::testing::PrintToStringParamName()); -XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim13) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR4FromArray4D(Array4D(2, 0, 4, 3)), {1, 3}); - ComputeAndCompareR4(&b, Array4D(2, 0, 4, 3), {}); -} +// A simple test class which not templated by float precision. +class ReverseTest : public ClientLibraryTestBase {}; // Tests the reverse operation on a 4D U8 array on dimension 0 and 3. XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { diff --git a/tensorflow/compiler/xla/tests/sample_file_test.cc b/tensorflow/compiler/xla/tests/sample_file_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..31b104f4e37f77d47f56ff8183ee1de1cc22e44d --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_file_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create a file based testcase +// and compare results on gpu and cpu. + +#include +#include + +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class SampleFileTest : public HloTestBase { + protected: + SampleFileTest() + : HloTestBase( + /*test_platform=*/PlatformUtil::GetPlatform("gpu").ValueOrDie(), + /*reference_platform=*/PlatformUtil::GetPlatform("cpu") + .ValueOrDie()) {} +}; + +TEST_F(SampleFileTest, Convolution) { + const string& filename = "compiler/xla/tests/isolated_convolution.hlo"; + string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); + EXPECT_TRUE(RunAndCompareFromFile( + tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/sample_text_test.cc b/tensorflow/compiler/xla/tests/sample_text_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4f2b74e3dc9e80f50454b28eb6f2502cef3e681 --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_text_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create textual IR based +// testcases. + +#include +#include + +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class SampleTextTest : public HloTestBase {}; + +TEST_F(SampleTextTest, Axpy) { + const string& hlo_string = R"( +HloModule axpy_module: +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.0001})); +} + +TEST_F(SampleTextTest, Tuple) { + const string& hlo_string = R"( +HloModule TupleCreate_module: +ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, nullopt)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index b5e7570778ffeca66cc15d7cd2b153639637a647..debf2d2d317fe64ca1ef86cb1f2978e76af1b55d 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -69,6 +69,13 @@ class ScalarComputationsTest : public ClientLibraryTestBase { } }; +XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR0(2.1f); + + ComputeAndCompareR0(&builder, 2.1f, {}, error_spec_); +} + XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) { ComputationBuilder builder(client_, TestName()); builder.Neg(builder.ConstantR0(2.1f)); diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c21124750ad512cad69b1483e708613ee2857ac0..ac163df127e0087c02777fa3d5ce7970c51b97b9 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -33,7 +34,6 @@ namespace xla { namespace { using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; class SliceTest : public ClientLibraryTestBase {}; @@ -211,6 +211,13 @@ class SliceR1Test : public ClientLibraryTestBase, } }; +string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { + const R1Spec& spec = data.param; + return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, + spec.slice_start, spec.slice_limit, + spec.slice_stride); +} + XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_F64) { Run(GetParam()); } @@ -223,30 +230,66 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } -INSTANTIATE_TEST_CASE_P( // - SliceR1TestInstantiation, // - SliceR1Test, // - ::testing::Values( // - R1Spec{10, 0, 0, 1}, // - R1Spec{10, 7, 7, 1}, // - R1Spec{10, 2, 4, 1}, // - R1Spec{10, 2, 4, 2}, // - R1Spec{10, 0, 10, 1}, // - R1Spec{1024, 1024 - 4, 1024, 1}, // - R1Spec{4096, 7, 7 + 1024, 1}, // - R1Spec{10, 0, 10, 2}, // - R1Spec{10, 0, 10, 3}, // - R1Spec{10, 0, 10, 4}, // - R1Spec{10, 0, 10, 5}, // - R1Spec{10, 0, 10, 10}, // - R1Spec{500, 200, 400, 7}, // - R1Spec{4096, 1, 4095, 3}, // - R1Spec{2047, 1024 - 24, 1024 + 160, 31}, // - R1Spec{2047, 1, 2046, 3 * 128}, // - R1Spec{4096, 1024 + 3, 4095, 500}, // - R1Spec{8192, 0, 8192, 1024 * 3 + 400} // - ) // +// Tests for R1 slice ops. +// The format for each testcase is {input size, start, limit, stride}. +// clang-format off +INSTANTIATE_TEST_CASE_P( + SliceR1TestInstantiation, + SliceR1Test, + ::testing::Values( + R1Spec{10, 0, 0, 1}, + R1Spec{10, 7, 7, 1}, + R1Spec{10, 0, 5, 1}, + R1Spec{10, 3, 5, 1}, + R1Spec{10, 0, 10, 1}, + R1Spec{1024, 0, 5, 1}, + R1Spec{1024, 3, 5, 1}, + R1Spec{1024 + 17, 0, 5, 1}, + R1Spec{1024 + 17, 3, 5, 1}, + R1Spec{1024 + 17, 1024, 1024 + 6, 1}, + R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1}, + R1Spec{1024, 1024 - 4, 1024, 1}, + R1Spec{4 * 1024, 7, 7 + 1024, 1}, + R1Spec{4 * 1024, 0, 4 * 1024, 1}, + R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1}, + R1Spec{4 * 1024, 1024, 3 * 1024, 1}, + R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1}, + R1Spec{16 * 1024, 0, 5, 1}, + R1Spec{16 * 1024, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 0, 5, 1}, + R1Spec{16 * 1024 + 17, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1}, + R1Spec{64 * 1024, 0, 64 * 1024, 1}, + R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1}, + R1Spec{64 * 1024, 1024, 63 * 1024, 1}, + R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1}, + R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}, +// TODO(b/69425338): This uses too much memory on GPU. +#ifndef XLA_TEST_BACKEND_GPU + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, +#endif + R1Spec{10, 2, 4, 2}, + R1Spec{10, 0, 10, 2}, + R1Spec{10, 0, 10, 3}, + R1Spec{10, 0, 10, 4}, + R1Spec{10, 0, 10, 5}, + R1Spec{10, 0, 10, 10}, + R1Spec{500, 200, 400, 7}, + R1Spec{4096, 1, 4095, 3}, + R1Spec{2047, 1024 - 24, 1024 + 160, 31}, + R1Spec{2047, 1, 2046, 3 * 128}, + R1Spec{4096, 1024 + 3, 4095, 500}, + R1Spec{8192, 0, 8192, 1024 * 3 + 400} + ), + SliceR1TestDataToString ); +// clang-format on struct R2Spec { int64 input_dim0; @@ -339,7 +382,7 @@ struct R4Spec { string R4SpecToString(const ::testing::TestParamInfo& data) { const R4Spec& spec = data.param; - return StrCat( // + return tensorflow::strings::StrCat( // "input_", Join(spec.input_dims, "x"), // "__layout_", Join(spec.input_layout, ""), // "__starts_", Join(spec.slice_starts, "x"), // diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 28a2d0198a707cec1aa5e0fbed341ee9b2a927f7..cc4eaf62f50d1fa622c705fab810fe1e1b0fbf08 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -36,6 +36,7 @@ limitations under the License. #define DISABLED_ON_CPU(X) X #define DISABLED_ON_CPU_PARALLEL(X) X #define DISABLED_ON_GPU(X) X +#define DISABLED_ON_INTERPRETER(X) X // We need this macro instead of pasting directly to support nesting // the DISABLED_ON_FOO macros, as in the definition of DISABLED_ON_CPU. @@ -62,6 +63,11 @@ limitations under the License. # define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X) #endif // XLA_TEST_BACKEND_GPU +#ifdef XLA_TEST_BACKEND_INTERPRETER +# undef DISABLED_ON_INTERPRETER +# define DISABLED_ON_INTERPRETER(X) XLA_TEST_PASTE(DISABLED_, X) +#endif // XLA_TEST_BACKEND_INTERPRETER + // clang-format on namespace xla { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 0d56c9f48363d0569921d7c76050dcc66208931b..8b10aef5b81c18648b6e255445d66a6d195f8a76 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -27,10 +28,31 @@ void PopulateWithRandomFloatingPointData(Literal* literal) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); std::minstd_rand0 engine; - std::uniform_real_distribution generator(0.0f, 1.0f); + // Create uniform numbers between 1 and 1.125 ot avoid creating denormal + // numbers. + std::uniform_real_distribution generator(1.0f, 1.125f); TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice indices) { + // Generate a random uniforma number from -0.0625 and 0.0625 and bias it + // with a position dependent nubmer with mean 0.037109375. These number + // should allow for long chains of accumulation without being too close + // to zero or to large to accumulate all numbers accurately. + return (generator(engine) - 1.0625) + + static_cast(Product(indices) % 113 - 47) / + static_cast(256.0f); + })); +} + +// The standard library does not have a case for bfloat16, unsurprisingly, so we +// handle that one specially. +template <> +void PopulateWithRandomFloatingPointData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), BF16); + std::minstd_rand0 engine; + std::uniform_real_distribution generator(-0.9f, 1.0f); + TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); + return static_cast(generator(engine)); })); } @@ -47,42 +69,136 @@ void PopulateWithRandomIntegralData(Literal* literal) { })); } -bool LooksLikeSum(const HloInstruction& instruction) { - return instruction.opcode() == HloOpcode::kAdd && - instruction.operand(0)->opcode() == HloOpcode::kParameter && - instruction.operand(1)->opcode() == HloOpcode::kParameter && - instruction.operand(0) != instruction.operand(1); +// Matches binary addition computations. +bool LooksLikeSum(const HloComputation& computation) { + const HloInstruction* const root = computation.root_instruction(); + return root->opcode() == HloOpcode::kAdd && + computation.num_parameters() == 2 && + root->operand(0)->opcode() == HloOpcode::kParameter && + root->operand(1)->opcode() == HloOpcode::kParameter && + root->operand(0) != root->operand(1); +} + +// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition, +// which requires an init_value of 0 rather than a random value. +bool NeedsZeroInitValue(const HloUse& use) { + const HloInstruction* const instruction = use.instruction; + const HloOpcode opcode = instruction->opcode(); + const int64 op_num = use.operand_number; + return ( + ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && + op_num == 1 && LooksLikeSum(*instruction->to_apply())) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2 && + LooksLikeSum(*instruction->scatter()))); } -// Given an instruction and operand number, replace the given operand with -// a Literal Constant Zero. Handle the case of a fusion instruction by -// replacing the fusion's parent's parameter with a Literal Constant Zero, -// unless the fusion's parent is itself a fusion. -Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction, - const int64 operand_number) { - CHECK_LT(operand_number, instruction->operand_count()); - if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) { - return Status::OK(); +// Generate random values that are constrained to the input_shape minus the +// output_shape so as not to produce wrapping slices, for instance. +std::unique_ptr MakeRandomNonwrappingSliceIndex( + const Shape& input_shape, const Shape& slice_shape) { + const int64 rank = ShapeUtil::Rank(input_shape); + std::vector start_indices(rank); + std::minstd_rand0 engine; + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(engine); } + return Literal::CreateR1(start_indices); +} - HloComputation* const computation = instruction->parent(); - std::unique_ptr zero = HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(instruction->shape().element_type()))); +// Use dataflow analysis on each parameter to see if there are uses that would +// be problematic when generating input data. Returns the list of instructions +// that correspond to their uses. +// +// Should be paired with the CreateLiteralForConstrainedUses() function below. +std::vector FindConstrainedUses( + const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + std::vector constrained_uses; + for (const auto& pair : dataflow.GetInstructionValueSet(¶m)) { + const HloValue& value = dataflow.GetUniqueValueAt(¶m, pair.first); + for (const HloUse& use : value.uses()) { + HloInstruction* instruction = use.instruction; + const HloOpcode opcode = instruction->opcode(); + const int64 op_num = use.operand_number; + if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || + (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { + constrained_uses.push_back(instruction); + } else if (opcode == HloOpcode::kFusion) { + const HloInstruction* const to_analyze = + instruction->fused_parameter(op_num); + auto fused_uses = FindConstrainedUses(dataflow, *to_analyze); + constrained_uses.insert(constrained_uses.end(), fused_uses.begin(), + fused_uses.end()); + } else if (NeedsZeroInitValue(use)) { + constrained_uses.push_back(instruction); + } else if (opcode == HloOpcode::kConvert || + opcode == HloOpcode::kReducePrecision) { + auto converted_uses = FindConstrainedUses(dataflow, *instruction); + constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), + converted_uses.end()); + } + } + } + return constrained_uses; +} - if (computation->IsFusionComputation()) { - HloInstruction* const fusion_instruction = computation->FusionInstruction(); - if (fusion_instruction->IsFused()) { - return Unimplemented( - "Unable to replace fused parameter of fusion instruction"); +// Given a parameter, generate a random Literal to use as input if there exist +// no constrained uses in the dataflow graph. If such constraints exist, +// generate a constrained literal (either bounded in the case of indices, or +// zero in the case of init_values for reductions). +StatusOr> CreateLiteralForConstrainedUses( + const tensorflow::gtl::ArraySlice constrained_uses, + const HloInstruction& param) { + HloInstruction* needs_index = nullptr; + HloInstruction* needs_zero = nullptr; + for (HloInstruction* use : constrained_uses) { + switch (use->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + if (needs_index != nullptr && + !ShapeUtil::Equal(needs_index->shape(), use->shape())) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } + needs_index = use; + break; + + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + needs_zero = use; + break; + + default: + return Unimplemented( + "Constrained operand generation not implemented for %s.", + use->ToString().c_str()); } - TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith( - instruction->operand(operand_number)->parameter_number(), - fusion_instruction->parent()->AddInstruction(std::move(zero)))); + } + if (needs_index != nullptr && needs_zero != nullptr) { + return Unimplemented( + "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " + "zero: %s\n", + needs_index->ToString().c_str(), needs_zero->ToString().c_str()); + } + if (needs_index != nullptr) { + return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), + needs_index->shape()); + } else if (needs_zero != nullptr) { + return Literal::CreateFromShape(param.shape()); } else { - TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith( - operand_number, computation->AddInstruction(std::move(zero)))); + return MakeFakeLiteral(param.shape()); } - return Status::OK(); +} + +// Given a module entry parameter, use the dataflow analysis to see if a +// special case literal must be created, or if we can generate fake data. +StatusOr> MakeConstrainedArgument( + const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + const auto constrained_uses = FindConstrainedUses(dataflow, param); + return CreateLiteralForConstrainedUses(constrained_uses, param); } } // namespace @@ -99,6 +215,9 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { } std::unique_ptr literal = Literal::CreateFromShape(shape); switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData(literal.get()); + break; case F32: PopulateWithRandomFloatingPointData(literal.get()); break; @@ -146,42 +265,20 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { } StatusOr>> MakeFakeArguments( - const HloModule& module) { - std::vector> arguments; - for (const ShapeLayout& shape_layout : - module.config().entry_computation_layout().parameter_layouts()) { - TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape())); - arguments.push_back(std::move(literal)); + HloModule* const module) { + TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module)); + const auto params = module->entry_computation()->parameter_instructions(); + std::vector> arguments(params.size()); + for (int i = 0; i < params.size(); ++i) { + TF_ASSIGN_OR_RETURN(arguments[i], + MakeConstrainedArgument(*dataflow, *params[i])); } return std::move(arguments); } -Status ReplaceInitsWithConstants(HloModule* const module) { - for (HloComputation* const computation : module->computations()) { - for (HloInstruction* const instruction : computation->instructions()) { - const HloOpcode opcode = instruction->opcode(); - if ((opcode == HloOpcode::kReduce || - opcode == HloOpcode::kReduceWindow) && - LooksLikeSum(*instruction->to_apply()->root_instruction())) { - TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1)); - } else if (opcode == HloOpcode::kSelectAndScatter && - LooksLikeSum(*instruction->scatter()->root_instruction())) { - TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2)); - } - } - } - return Status::OK(); -} - Status VerifyHloModule(const perftools::gputools::Platform& platform, HloModule* const module) { - return HloVerifier( - std::bind( - &TransferManager::GetByteSizeRequirement, - TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(), - std::placeholders::_1)) - .Run(module) - .status(); + return HloVerifier().Run(module).status(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 9aca162a185e5b22888229555b7bce88769c79a6..0fb024ffb074f1c90b75022bc7f5a8b58b03c0c2 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -60,13 +60,11 @@ StatusOr> MakeFakeLiteral(const Shape& shape); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. +// +// Will handle special cases such as making sure that indices used for dynamic +// slices are bounded, reduces that call adds use 0 as an init value, etc. StatusOr>> MakeFakeArguments( - const HloModule& module); - -// Reductions using Adds, ReduceWindow, and SelectAndScatter, require their -// init_value to be replaced with the constant 0.0f when testing, otherwise we -// may generate a bad init_value when looking at the op in isolation. -Status ReplaceInitsWithConstants(HloModule* const module); + HloModule* const module); // Check that a given module satisfies various constraints before trying to // execute it. diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index f2a64749482e5f5a8c5d72034fb7a4eee07baf48..268ba338f2e6740a1d1a046d5a85494f3cf2e9f8 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -46,9 +46,10 @@ class TransferManagerTest : public LocalClientTestBase { ~TransferManagerTest() override = default; std::unique_ptr AllocateDeviceBuffer(const Shape& shape) { - return ScopedShapedBuffer::Allocate( - shape, GetOrCreateAllocator(local_client_->platform()), - /*device_ordinal=*/0, shape_size_fn_) + return transfer_manager_ + ->AllocateScopedShapedBuffer( + shape, GetOrCreateAllocator(local_client_->platform()), + /*device_ordinal=*/0) .ValueOrDie(); } @@ -118,7 +119,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, *device_buffer)); - EXPECT_EQ(result->u8s_string(), test_string); + EXPECT_EQ(result->GetR1U8AsString(), test_string); } XLA_TEST_F(TransferManagerTest, TransferR2F32) { @@ -211,5 +212,39 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { LiteralTestUtil::ExpectEqual(*literal, *result); } +XLA_TEST_F(TransferManagerTest, TransferComplexValue) { + std::unique_ptr literal = Literal::CreateR1( + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR1( + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) + .get(), + Literal::CreateR1({1, 2, 3, 4, 5, 6}).get(), + Literal::CreateR0(complex64(0.3f, -0.4f)).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 5a012c93d64f6a6fca73aa422e20cf238c945ce9..a8bca70d85ddf168bc441231d6f43bead019b10a 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -57,6 +57,20 @@ XLA_TEST_F(TupleTest, TupleConstant) { ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } +// Tests a tuple made of scalar constants. +XLA_TEST_F(TupleTest, TupleScalarConstant) { + ComputationBuilder builder(client_, TestName()); + + const float constant_scalar1 = 7.3f; + const float constant_scalar2 = 1.2f; + auto value = + Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), + Literal::CreateR0(constant_scalar2).get()}); + + auto result = builder.ConstantLiteral(*value); + ComputeAndCompareTuple(&builder, *value, {}, error_spec_); +} + // Tests the creation of tuple data. XLA_TEST_F(TupleTest, TupleCreate) { ComputationBuilder builder(client_, TestName()); @@ -445,5 +459,61 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { ComputeAndCompareR1(&builder, expected, arguments, ErrorSpec(1e-5)); } +XLA_TEST_F(TupleTest, ComplexTuples) { + ComputationBuilder builder(client_, TestName()); + { + Shape c64r0 = ShapeUtil::MakeShape(C64, {}); + Shape c64r1 = ShapeUtil::MakeShape(C64, {2}); + Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2}); + Shape arg0_shape = ShapeUtil::MakeTupleShape( + {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})}); + auto input0 = builder.Parameter(0, arg0_shape, "input0"); + auto t0 = builder.GetTupleElement(input0, 0); + auto t1 = builder.GetTupleElement(input0, 1); + auto t10 = builder.GetTupleElement(t1, 0); + auto t11 = builder.GetTupleElement(t1, 1); + auto sum = builder.Add(builder.Add(t10, t11, {1}), t0); + auto input1 = builder.Parameter(1, c64r1, "input1"); + auto prod = builder.Mul(input1, sum, {1}); + builder.Tuple({builder.Tuple({prod, sum}), + builder.ConstantR0({123, 456})}); + } + + std::unique_ptr arg0 = + client_ + ->TransferToServer(*Literal::MakeTuple( + {Literal::CreateR0({1, 2}).get(), + Literal::MakeTuple( + {Literal::CreateR1({{10, 20}, {30, 40}}).get(), + Literal::CreateR2( + {{{100, 200}, {300, 400}}, + {{1000, 2000}, {3000, 4000}}, + {{10000, 20000}, {30000, 40000}}}) + .get()}) + .get()})) + .ConsumeValueOrDie(); + std::unique_ptr arg1 = + client_ + ->TransferToServer(*Literal::CreateR1({{1, 2}, {1, -2}})) + .ConsumeValueOrDie(); + auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, + {{1011, 2022}, {3031, 4042}}, + {{10011, 20022}, {30031, 40042}}}); + auto prod = Literal::CreateFromShape(sum->shape()); + ASSERT_TRUE(prod->Populate( + [&sum](tensorflow::gtl::ArraySlice indexes) { + return sum->Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? complex64(1, 2) + : complex64(1, -2)); + }) + .ok()); + auto expected = + Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(), + Literal::CreateR0({123, 456}).get()}); + ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 49f673f5f0bf9b844ab4030383784208b4e2c58a..52157b837c383205f77a030ef98b2fd03a41aff5 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -357,8 +357,7 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { +TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -411,8 +410,7 @@ TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) { +TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -565,6 +563,53 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); } +TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(S32, {})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(5), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and set the other tuple element to a + // constant. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto result = + builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), + builder.ConstantR0(7)}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR0(7)}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(5); + auto expected_data = Literal::CreateR0(7); + auto expected = + Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); +} + // Tests two while nodes when the result type T is a Tuple and the second // while node uses the result of the first while node which is used in two // nodes. @@ -913,8 +958,7 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { } } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); @@ -950,8 +994,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { ErrorSpec(1e-6)); } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithBroadcast)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); @@ -1164,6 +1207,50 @@ TEST_F(WhileTest, WhileWithCallInsideCondition) { ComputeAndCompareR0(&builder, 5, {}); } +TEST_F(WhileTest, WhileWithLoopInvariantOperation) { + auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto while_shape = ShapeUtil::MakeTupleShape( + {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); + + // Create a computation for the condition: repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto state = builder.Parameter(0, while_shape, "state"); + builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto state = builder.Parameter(0, while_shape, "state"); + auto indvar = builder.GetTupleElement(state, 0); + auto input_0 = builder.GetTupleElement(state, 1); + auto input_1 = builder.GetTupleElement(state, 2); + auto output = builder.Tanh(builder.Dot(input_0, input_1)); + auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); + auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + ComputationBuilder builder(client_, TestName()); + auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); + auto init = builder.Tuple( + {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); + auto while_instruction = builder.While(condition, body, init); + builder.GetTupleElement(while_instruction, 3); + + TF_ASSERT_OK_AND_ASSIGN(auto param_value, + client_->TransferToServer(*Literal::CreateR2( + {{1.0, 2.0}, {-1.0, -2.0}}))); + + ComputeAndCompareR2( + &builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, + {param_value.get()}, ErrorSpec(4e-5)); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..146fbadcb68e6c5d0fa0856c1c98b399df72051f --- /dev/null +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -0,0 +1,308 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { +namespace se = ::perftools::gputools; + +class HloProfileTest : public ClientLibraryTestBase {}; + +struct ParsedProfileOutputLine { + int64 cycles; + string cycles_percentage; + double usec; + string flops; + string trops; + string bytes_per_sec; + string bytes_per_cycle; + string name; +}; + +StatusOr ParseProfileOutputLine(const string& line, + bool expect_flops, + bool expect_trops) { + string separator = "[^:]*:: +"; + string match_percentage = "\\d+\\.\\d\\d%"; + string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; + string match_usecs = "([0-9.]+) usec"; + string match_flops = expect_flops ? "([0-9.TGMk]+)FLOP/s" : "()"; + string match_trops = expect_trops ? "([0-9.TGMk]+)TROP/s" : "()"; + string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; + string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; + string regexp_pattern = tensorflow::strings::StrCat( + " +", match_cycles, separator, match_usecs, separator, match_flops, + separator, match_trops, separator, match_bytes_per_sec, separator, + match_bytes_per_cycle, separator, "(.*)"); + + RE2 pattern(regexp_pattern); + ParsedProfileOutputLine parsed_line; + bool matched = RE2::FullMatch( + line, pattern, &parsed_line.cycles, &parsed_line.cycles_percentage, + &parsed_line.usec, &parsed_line.flops, &parsed_line.trops, + &parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle, + &parsed_line.name); + if (!matched) { + return tensorflow::errors::InvalidArgument( + "Input did not match regexp. Input: ", line, + ", Regexp: ", regexp_pattern); + } + + return parsed_line; +} + +// Returns void so that we can ASSERT. +void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, + const Computation& computation, + const Shape& lhs_arg_shape, + const Shape& rhs_arg_shape) { + LocalService* service = ClientLibrary::GetXlaService(client->platform()); + Backend* backend = service->mutable_backend(); + se::StreamExecutor* executor = backend->default_stream_executor(); + DeviceMemoryAllocator* allocator = backend->memory_allocator(); + auto* transfer_manager = backend->transfer_manager(); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr lhs_arg, + transfer_manager->AllocateScopedShapedBuffer( + lhs_arg_shape, allocator, backend->default_device_ordinal())); + TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( + executor, *Literal::CreateFromShape(lhs_arg_shape), *lhs_arg)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr rhs_arg, + transfer_manager->AllocateScopedShapedBuffer( + rhs_arg_shape, allocator, backend->default_device_ordinal())); + TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( + executor, *Literal::CreateFromShape(rhs_arg_shape), *rhs_arg)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr local_executable, + client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, + ExecutableBuildOptions())); + + Executable* executable = local_executable->executable(); + HloExecutionProfile hlo_execution_profile( + &executable->hlo_profile_printer(), &executable->hlo_profile_index_map()); + + TF_ASSERT_OK_AND_ASSIGN( + Backend::StreamPtr stream_ptr, + backend->BorrowStream(backend->default_device_ordinal())); + ExecutableRunOptions exec_run_options; + exec_run_options.set_stream(stream_ptr.get()); + exec_run_options.set_allocator(backend->memory_allocator()); + exec_run_options.set_intra_op_thread_pool( + backend->eigen_intra_op_thread_pool_device()); + ServiceExecutableRunOptions run_options( + exec_run_options, /*borrow_stream=*/nullptr, + backend->eigen_intra_op_thread_pool()); + TF_ASSERT_OK_AND_ASSIGN( + auto execution_result, + executable->ExecuteOnStream(&run_options, {lhs_arg.get(), rhs_arg.get()}, + &hlo_execution_profile)); + (void)execution_result; + + *profile_output = + hlo_execution_profile.ToString(executor->GetDeviceDescription()); + + XLA_VLOG_LINES(4, *profile_output); +} + +// TODO(b/71364943): This test exposes a bug in the parallel CPU backend. +XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { + const int64 m = 256, k = 256, n = 256; + Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + PlatformUtil::GetDefaultPlatform()); + TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(platform)); + + ComputationBuilder builder(client, TestName()); + auto result = builder.Tanh(builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); + + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); + + string profile_output; + ExecuteAndFetchProfile(&profile_output, client, computation, lhs_shape, + rhs_shape); + + std::vector profile_output_lines = + tensorflow::str_util::Split(profile_output, '\n'); + + TF_ASSERT_OK_AND_ASSIGN( + ParsedProfileOutputLine total_profile, + ParseProfileOutputLine(profile_output_lines[1], /*expect_flops=*/true, + /*expect_trops=*/true)); + + TF_ASSERT_OK_AND_ASSIGN( + ParsedProfileOutputLine dot_profile, + ParseProfileOutputLine(profile_output_lines[2], /*expect_flops=*/true, + /*expect_trops=*/false)); + + TF_ASSERT_OK_AND_ASSIGN( + ParsedProfileOutputLine tanh_profile, + ParseProfileOutputLine(profile_output_lines[3], /*expect_flops=*/false, + /*expect_trops=*/true)); + + EXPECT_GT(total_profile.cycles, 0); + EXPECT_EQ(total_profile.cycles_percentage, "100.00%"); + + EXPECT_GT(total_profile.cycles, dot_profile.cycles); + EXPECT_NE(dot_profile.cycles_percentage, "0.00%"); + EXPECT_NE(dot_profile.cycles_percentage, "100.00%"); + + EXPECT_GT(total_profile.cycles, tanh_profile.cycles); + EXPECT_NE(tanh_profile.cycles_percentage, "0.00%"); + EXPECT_NE(tanh_profile.cycles_percentage, "100.00%"); +} + +// TODO(b/71364943): This test exposes a bug in the parallel CPU backend. +// +// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo +// instructions "interior" to while nodes. +XLA_TEST_F(HloProfileTest, + DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(ProfileWhileComputation))) { + const int64 size = 256; + Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size}); + Shape while_result_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}), matrix_shape}); + + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + PlatformUtil::GetDefaultPlatform()); + TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(platform)); + + Computation condition; + { + ComputationBuilder builder(client, "condition"); + auto state = builder.Parameter(0, while_result_shape, "state"); + auto iteration = builder.GetTupleElement(state, 0); + builder.Gt(builder.ConstantR0(5), iteration); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + Computation body; + { + ComputationBuilder builder(client, "body"); + auto state = builder.Parameter(0, while_result_shape, "state"); + auto matrix = builder.GetTupleElement(state, 1); + auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), + builder.ConstantR0(1)); + builder.Tuple({next_iteration, builder.Dot(matrix, matrix)}); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + ComputationBuilder builder(client, TestName()); + auto initial_while_state = + builder.Tuple({builder.ConstantR0(0), + builder.Parameter(0, matrix_shape, "initial_value")}); + auto while_result = builder.While(condition, body, initial_while_state); + builder.Add(builder.GetTupleElement(while_result, 1), + builder.Parameter(1, matrix_shape, "other_value")); + + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); + + string profile_output; + ExecuteAndFetchProfile(&profile_output, client, computation, matrix_shape, + matrix_shape); + + std::vector profile_output_lines = + tensorflow::str_util::Split(profile_output, '\n'); + + auto while_body_profile_start = + std::find_if(profile_output_lines.begin(), profile_output_lines.end(), + [](tensorflow::StringPiece s) { + return s.starts_with("Execution profile for body"); + }); + + ASSERT_NE(while_body_profile_start, profile_output_lines.end()); + + TF_ASSERT_OK_AND_ASSIGN( + ParsedProfileOutputLine total_while_body_profile, + ParseProfileOutputLine(*std::next(while_body_profile_start, 1), + /*expect_flops=*/false, + /*expect_trops=*/false)); + + TF_ASSERT_OK_AND_ASSIGN( + ParsedProfileOutputLine dot_profile, + ParseProfileOutputLine(*std::next(while_body_profile_start, 2), + /*expect_flops=*/false, + /*expect_trops=*/false)); + + EXPECT_GT(total_while_body_profile.cycles, 0); + EXPECT_EQ(total_while_body_profile.name, "[total]"); + EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%"); + + EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles); + EXPECT_NE(dot_profile.cycles_percentage, "0.00%"); + EXPECT_NE(dot_profile.cycles_percentage, "100.00%"); +} +} // namespace +} // namespace xla + +static std::pair AddXlaHloProfileFlag(int argc, char** argv) { + // Intentional "leak". + char** new_argv = new char*[argc + 2]; + for (int i = 0; i < argc; i++) { + new_argv[i] = argv[i]; + } + + // We do it this way (as opposed to piping in a modified DebugOptions + // instance) for better end-to-end integration testing. + new_argv[argc] = strdup("--xla_hlo_profile"); + + // Fusion can change the Hlo instructions that show up in the final Hlo + // executable, so block it here. + new_argv[argc + 1] = strdup("--xla_disable_hlo_passes=fusion"); + return {argc + 2, new_argv}; +} + +GTEST_API_ int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv); + + auto usage = tensorflow::Flags::Usage(argv[0], flag_list); + if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { + LOG(ERROR) << "\n" << usage; + return 2; + } + + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 4d060895d357493327ec50b38016478c65fef94d..6fa4c48e11d1102367b21bc21d4734466495ef0e 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -102,9 +102,9 @@ StatusOr> TextLiteralReader::ReadAllLines() { ShapeUtil::HumanString(shape).c_str()); } - auto result = MakeUnique(); + auto result = MakeUnique(shape); const float fill = std::numeric_limits::quiet_NaN(); - result->PopulateWithValue(fill, AsInt64Slice(shape.dimensions())); + result->PopulateWithValue(fill); std::vector pieces; std::vector coordinates; std::vector coordinate_values; diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 78d8fb1f4330aed899ca917e66fae819a002b3a9..24417a0cb8212e59cc0af53bd5bb21afcf3e134b 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -69,7 +69,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { fprintf(stdout, "HLO compiled for %s backend:\n%s\n", local_service->backend().platform()->Name().c_str(), - module.ToString().c_str()); + module.ToString(HloPrintOptions::ShortParsable()).c_str()); } else { const ComputationTracker& tracker = local_service->computation_tracker(); UserComputation* user_computation = @@ -80,7 +80,8 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { tracker.BuildHloModule(versioned_handle, HloModuleConfig()) .ConsumeValueOrDie(); - fprintf(stdout, "%s\n", module->ToString().c_str()); + fprintf(stdout, "%s\n", + module->ToString(HloPrintOptions::ShortParsable()).c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index ce936af6c3376387c1ed9fa48da23b8af537f6e5..97aacf6b39f83978e732060817cd93ede81ca782 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -34,9 +34,9 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", ], diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 6232967f5f04cbf316d985357ae84c28335531e2..f0f3dd7785c13e505e1eb6d4c8cd4bad157c4993 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -1,24 +1,26 @@ -# HloModule string syntax - -TODO: Support all subcomputations (for fusion, reduce, ...). - -TODO: Support all extra attributes, e.g. dimensions, strides. +# HLO Text Syntax ```yacc hlo_module : 'HloModule' name computations ; +/* If no computation is marked as ENTRY, the last computation will be the entry +computation of the module.*/ computations : computation | computation computations ; computation - : 'ENTRY' name param_list '->' shape instruction_list - | name param_list '->' shape instruction_list + : 'ENTRY' name param_list_to_shape instruction_list + | name param_list_to_shape instruction_list + | 'ENTRY' name instruction_list + | name instruction_list ; +/* If no instruction is marked as ROOT, the last instruction will be the root of +its computation. */ instruction_list : '{' instruction_list1 '}' ; @@ -41,6 +43,7 @@ operands1 ; operand : shape name + | name ; attributes @@ -60,6 +63,10 @@ attribute_value | '{' sub_attributes '}' ; +param_list_to_shape + : param_list '->' shape + ; + param_list : '(' param_list1 ')' ; @@ -84,6 +91,7 @@ tuple_elements name : identifier ':' | '%' identifier + | identifier ; identifier @@ -108,7 +116,29 @@ non_tuple | rank2345 ; rank2345 - : shape nested_array + : shape sparse_or_nested_array + ; +sparse_or_nested_array + : sparse_array + | nested_array + ; +sparse_array + : '{' sparse_array1 '}' + ; +sparse_array1 + : sparse_array_item + | sparse_array1 ',' sparse_array_item + ; +sparse_array_item + : multi_index ':' scalar + ; +multi_index + : kInt + | '[' multi_index1 ']' + ; +multi_index1 + : kInt + | multi_index1 ',' kInt ; ``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index 56744440db1b17aa1cc8823feb1bad279f8f4f75..fc0e4444521247734fc240a03da669244fe1a6a4 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -153,21 +152,21 @@ TokKind HloLexer::LexToken() { } } -// Lex a shape, name, keyword, opcode, attribute name, or the dim labels -// pattern. +// Lex a shape, name, keyword, attribute name, the dim labels pattern, and +// other identifiers. // // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: // keyword ::= HloModule, ENTRY, ... -// opcode ::= add, greater-than, ... // attribute_name ::= condition, body, dimensions, ... // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} +// identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); // 'consumable' will be advanced iff its prefix matches the pattern. static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,]*)\](?:{([\d,]*)})?)"}; + R"(^(\w*\d*)\[([\d,]*)\](?:(dense|sparse)?{([\d,]+)})?)"}; if (RE2::Consume(&consumable, *shape_pattern)) { auto status_or_shape = ShapeUtil::ParseShapeString( StringPieceFromPointers(token_start_, consumable.begin())); @@ -220,20 +219,6 @@ TokKind HloLexer::LexIdentifier() { #undef KEYWORD - // See if this is an opcode. - auto opcode = StringToHloOpcode(identifier.ToString()); - if (opcode.ok()) { - opcode_val_ = opcode.ValueOrDie(); - return TokKind::kOpcode; - } - - // See if this is an fusion kind. - auto kind = xla::StringToFusionKind(identifier.ToString()); - if (kind.ok()) { - fusion_kind_val_ = kind.ValueOrDie(); - return TokKind::kFusionKind; - } - { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); static LazyRE2 dim_labels_pattern = { @@ -244,8 +229,9 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kDimLabels; } } - current_ptr_ = token_start_ + 1; - return TokKind::kError; + + str_val_ = identifier.ToString(); + return TokKind::kIdent; } // Lex names after a % character. @@ -271,7 +257,8 @@ TokKind HloLexer::LexPercent() { // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} // dxd_pattern ::= [0-9]+(x[0-9]+)+ -// pad_pattern ::= [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* +// pad_pattern ::= +// [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)* // int ::= [-]?[0-9]+ // negative inf ::= '-inf' TokKind HloLexer::LexNumberOrPattern() { @@ -289,7 +276,7 @@ TokKind HloLexer::LexNumberOrPattern() { R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"}; static LazyRE2 pad_pattern = { - R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)"}; + R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"}; if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); @@ -326,18 +313,43 @@ TokKind HloLexer::LexNumberOrPattern() { return TokKind::kError; } -StringPiece HloLexer::GetCurrentLine() const { - const char* start = token_start_; - const char* end = current_ptr_; - if (!CanDereference(start) || !CanDereference(end)) { - return "LINE OUT OF RANGE"; +std::pair HloLexer::GetLineAndColumn(LocTy location) const { + unsigned line_no = 1; + const char* start = buf_.begin(); + const char* ptr = start; + if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) && + line_no_cache_.last_query <= location) { + ptr = line_no_cache_.last_query; + line_no = line_no_cache_.line_no_of_query; } - while (start > buf_.begin() && *start != '\n') { - start--; + for (; ptr != location; ptr++) { + if (*ptr == '\n') { + line_no++; + } } - while (end < buf_.end() && *end != '\n') { - end++; + + // Update the line number cache. + line_no_cache_.last_query = ptr; + line_no_cache_.line_no_of_query = line_no; + size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); + if (line_offset == StringPiece::npos) { + line_offset = 0; } + return {line_no, ptr - start - line_offset}; +} + +StringPiece HloLexer::GetLine(LocTy loc) const { + if (!CanDereference(loc)) { + return "LINE OUT OF RANGE"; + } + size_t line_start = + StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); + const char* start = line_start == StringPiece::npos + ? buf_.begin() + : buf_.begin() + line_start + 1; + size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); + const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + return StringPieceFromPointers(start, end); } @@ -428,14 +440,12 @@ string TokKindToString(TokKind kind) { return "kDxD"; case TokKind::kPad: return "kPad"; + case TokKind::kIdent: + return "kIdent"; case TokKind::kString: return "kString"; case TokKind::kShape: return "kShape"; - case TokKind::kOpcode: - return "kOpcode"; - case TokKind::kFusionKind: - return "kFusionKind"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 5c9d1bf3912584040dc5260cc6730247d439fd60..27880b9b8afbfa58abfedc3b2cecd5236b78a6d6 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -18,9 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" @@ -48,6 +47,7 @@ class HloLexer { case TokKind::kDxD: case TokKind::kPad: case TokKind::kString: + case TokKind::kIdent: return str_val_; default: LOG(FATAL) << "This token does not have string value"; @@ -57,14 +57,6 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - HloOpcode GetOpcodeVal() const { - CHECK(GetKind() == TokKind::kOpcode); - return opcode_val_; - } - HloInstruction::FusionKind GetFusionKindVal() const { - CHECK(GetKind() == TokKind::kFusionKind); - return fusion_kind_val_; - } int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; @@ -74,8 +66,16 @@ class HloLexer { return decimal_val_; } - // Returns the line of text that is currently being lexed. - tensorflow::StringPiece GetCurrentLine() const; + typedef const char* LocTy; + + // Returns the location of the current token. + LocTy GetLoc() const { return token_start_; } + + // Returns the line and column of a location in the buffer. + std::pair GetLineAndColumn(LocTy location) const; + + // Returns the whole line given the location. + tensorflow::StringPiece GetLine(LocTy loc) const; private: // Returns the current character. If it's neither the end of input buffer nor @@ -114,10 +114,15 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - HloOpcode opcode_val_; - HloInstruction::FusionKind fusion_kind_val_; int64 int64_val_; double decimal_val_; + + struct LineNoCacheTy { + const char* last_query; + unsigned line_no_of_query; + }; + // This caches the line number of the previous query. + mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; } // namespace tools diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 47979ec6f361789f29e8f7ff47793747330551fc..1c68e271e0f75d8facc36bd0878190f3db512972 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -40,6 +41,8 @@ const double kF16max = 65504; // Parser for the HloModule::ToString() format text. class HloParser { public: + using LocTy = HloLexer::LocTy; + explicit HloParser(StringPiece str, const HloModuleConfig& config) : lexer_(str), config_(config) {} @@ -56,7 +59,7 @@ class HloParser { // ParseXXX returns false if an error occurred. bool ParseHloModule(); bool ParseComputations(); - bool ParseComputation(); + bool ParseComputation(HloComputation** entry_computation); bool ParseInstructionList(HloComputation::Builder* builder, string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); @@ -65,6 +68,13 @@ class HloParser { bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); bool ParseNonTupleLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseSparseLiteral(std::unique_ptr* literal, + const Shape& shape); + template + bool ParseSparseLiteralHelper(std::unique_ptr* literal, + const Shape& shape); + // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); @@ -96,6 +106,7 @@ class HloParser { kString, kBracedInt64List, kHloComputation, + kFftType, kWindow, kConvolutionDimensionNumbers, kSharding, @@ -104,6 +115,7 @@ class HloParser { kPaddingConfig, kMetadata, kFusionKind, + kDistribution, }; struct AttrConfig { @@ -167,20 +179,30 @@ class HloParser { bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); + bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); bool ParseName(string* result); bool ParseAttributeName(string* result); bool ParseString(string* result); bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); + bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); + bool ParseRandomDistribution(RandomDistribution* result); bool ParseInt64(int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); + // Returns true if the current token is the beginning of a shape. + bool CanBeShape(); + // Returns true if the current token is the beginning of a + // param_list_to_shape. + bool CanBeParamListToShape(); + // Logs the current parsing line and the given message. Always returns false. bool TokenError(StringPiece msg); + bool Error(LocTy loc, StringPiece msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. @@ -191,10 +213,12 @@ class HloParser { // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. - bool AddInstruction(const string& name, HloInstruction* instruction); + bool AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc); // Adds the computation to the pool. Returns false and emits an error if the // computation already exists. - bool AddComputation(const string& name, HloComputation* computation); + bool AddComputation(const string& name, HloComputation* computation, + LocTy name_loc); // The map from the instruction name to the instruction. This does not own the // instructions. @@ -203,19 +227,30 @@ class HloParser { HloLexer lexer_; std::unique_ptr module_; + std::vector> computations_; const HloModuleConfig config_; std::vector error_; }; -bool HloParser::TokenError(StringPiece msg) { - const string error = - StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; token ", - TokKindToString(lexer_.GetKind()), "; ", msg); - VLOG(1) << "TokenError: " << error; - error_.push_back(error); +bool HloParser::Error(LocTy loc, StringPiece msg) { + auto line_col = lexer_.GetLineAndColumn(loc); + const unsigned line = line_col.first; + const unsigned col = line_col.second; + std::vector error_lines; + error_lines.push_back( + StrCat("was parsing ", line, ":", col, ": error: ", msg)); + error_lines.push_back(lexer_.GetLine(loc).ToString()); + error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); + + error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + VLOG(1) << "Error: " << error_.back(); return false; } +bool HloParser::TokenError(StringPiece msg) { + return Error(lexer_.GetLoc(), msg); +} + bool HloParser::Run() { lexer_.Lex(); return ParseHloModule(); @@ -241,27 +276,67 @@ bool HloParser::ParseHloModule() { // computations ::= (computation)+ bool HloParser::ParseComputations() { + HloComputation* entry_computation = nullptr; do { - if (!ParseComputation()) { + if (!ParseComputation(&entry_computation)) { return false; } } while (lexer_.GetKind() != TokKind::kEof); + + for (int i = 0; i < computations_.size(); i++) { + // If entry_computation is not nullptr, it means the computation it pointed + // to is marked with "ENTRY"; otherwise, no computation is marked with + // "ENTRY", and we use the last computation as the entry computation. We + // add the non-entry computations as embedded computations to the module. + if ((entry_computation != nullptr && + computations_[i].get() != entry_computation) || + (entry_computation == nullptr && i != computations_.size() - 1)) { + module_->AddEmbeddedComputation(std::move(computations_[i])); + continue; + } + auto computation = + module_->AddEntryComputation(std::move(computations_[i])); + // The parameters and result layouts were set to default layout. Here we + // set the layouts to what the hlo text says. + for (int p = 0; p < computation->num_parameters(); p++) { + const Shape& param_shape = computation->parameter_instruction(p)->shape(); + if (param_shape.has_layout()) { + module_->mutable_entry_computation_layout() + ->mutable_parameter_layout(p) + ->ResetLayout(param_shape.layout()); + } + } + const Shape& result_shape = computation->root_instruction()->shape(); + if (result_shape.has_layout()) { + module_->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(result_shape.layout()); + } + } + return true; } -// computation ::= ('ENTRY')? name param_list '->' shape instruction_list -bool HloParser::ParseComputation() { +// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list +bool HloParser::ParseComputation(HloComputation** entry_computation) { + LocTy maybe_entry_loc = lexer_.GetLoc(); const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); + string name; + LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name)) { return false; } auto builder = MakeUnique(name); + LocTy shape_loc = nullptr; Shape shape; + if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) { + return false; + } + string root_name; - if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") || - !ParseShape(&shape) || !ParseInstructionList(builder.get(), &root_name)) { + if (!ParseInstructionList(builder.get(), &root_name)) { return false; } @@ -273,14 +348,37 @@ bool HloParser::ParseComputation() { LOG(FATAL) << "instruction " << root_name << " was marked as ROOT but the parser has not seen it before"; } + // Now root can be either an existing instruction or a nullptr. If it's a // nullptr, the implementation of Builder will set the last instruction as // root instruction. - HloComputation* computation = - is_entry_computation - ? module_->AddEntryComputation(builder->Build(root)) - : module_->AddEmbeddedComputation(builder->Build(root)); - return AddComputation(name, computation); + computations_.emplace_back(builder->Build(root)); + HloComputation* computation = computations_.back().get(); + + if (!root) { + root = computation->root_instruction(); + } else { + CHECK_EQ(root, computation->root_instruction()); + } + + // If param_list_to_shape was present, check compatibility. + if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { + return Error( + shape_loc, + StrCat("Shape of computation ", name, ", ", + ShapeUtil::HumanString(shape), + ", is not compatible with that of its root instruction ", + root_name, ", ", ShapeUtil::HumanString(root->shape()))); + } + + if (is_entry_computation) { + if (*entry_computation != nullptr) { + return Error(maybe_entry_loc, "expects only one ENTRY"); + } + *entry_computation = computation; + } + + return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' @@ -307,13 +405,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, Shape shape; HloOpcode opcode; std::vector operands; + + LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); + + const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || !ParseToken(TokKind::kEqual, "expects '=' in instruction") || !ParseShape(&shape) || !ParseOpcode(&opcode)) { return false; } + if (is_root) { + if (!root_name->empty()) { + return Error(maybe_root_loc, "one computation should have only one ROOT"); + } *root_name = name; } @@ -395,7 +501,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kLe: case HloOpcode::kLt: case HloOpcode::kNe: - case HloOpcode::kDot: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -444,12 +549,11 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands[0])); + HloInstruction::CreateCrossReplicaSum(shape, operands)); break; } case HloOpcode::kReshape: { @@ -590,6 +694,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); break; } + case HloOpcode::kFft: { + optional fft_type; + optional> fft_length; + attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; + attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, + &fft_length}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateFft( + shape, operands[0], *fft_type, *fft_length)); + break; + } case HloOpcode::kBroadcast: { optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, @@ -816,15 +934,110 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands[0], config ? *config : "")); break; } - case HloOpcode::kConditional: - case HloOpcode::kCustomCall: - case HloOpcode::kReducePrecision: - case HloOpcode::kRng: + case HloOpcode::kRng: { + optional distribution; + attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution, + &distribution}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateRng(shape, *distribution, operands)); + break; + } + case HloOpcode::kReducePrecision: { + optional exponent_bits; + optional mantissa_bits; + attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, + &exponent_bits}; + attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, + &mantissa_bits}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateReducePrecision( + shape, operands[0], static_cast(*exponent_bits), + static_cast(*mantissa_bits))); + break; + } + case HloOpcode::kConditional: { + optional true_computation; + optional false_computation; + attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation, + &true_computation}; + attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation, + &false_computation}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateConditional( + shape, /*pred=*/operands[0], + /*true_computation_arg=*/operands[1], *true_computation, + /*false_computation_arg=*/operands[2], *false_computation)); + break; + } + case HloOpcode::kCustomCall: { + optional custom_call_target; + attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, + &custom_call_target}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target)); + break; + } + case HloOpcode::kDot: { + optional> lhs_contracting_dims; + attrs["lhs_contracting_dims"] = { + /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; + optional> rhs_contracting_dims; + attrs["rhs_contracting_dims"] = { + /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; + optional> lhs_batch_dims; + attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, + &lhs_batch_dims}; + optional> rhs_batch_dims; + attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, + &rhs_batch_dims}; + + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + + DotDimensionNumbers dnum; + if (lhs_contracting_dims) { + *dnum.mutable_lhs_contracting_dimensions() = { + lhs_contracting_dims->begin(), lhs_contracting_dims->end()}; + } + if (rhs_contracting_dims) { + *dnum.mutable_rhs_contracting_dimensions() = { + rhs_contracting_dims->begin(), rhs_contracting_dims->end()}; + } + if (lhs_batch_dims) { + *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(), + lhs_batch_dims->end()}; + } + if (rhs_batch_dims) { + *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(), + rhs_batch_dims->end()}; + } + + instruction = builder->AddInstruction( + HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } + instruction->set_name(name); + // Add common attrs (sharding, control predecessors) to the instruction, if // they were seen. if (sharding) { @@ -835,15 +1048,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, for (auto* pre : *predecessors) { Status status = pre->AddControlDependencyTo(instruction); if (!status.ok()) { - return TokenError(StrCat("error adding control dependency for: ", name, - " status: ", status.ToString())); + return Error(name_loc, StrCat("error adding control dependency for: ", + name, " status: ", status.ToString())); } } } if (metadata) { instruction->set_metadata(*metadata); } - return AddInstruction(name, instruction); + return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) // ::= '{' (single_sharding | tuple_sharding) '}' @@ -889,6 +1102,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } + LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; std::vector devices; @@ -956,34 +1170,35 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, if (replicated) { if (!devices.empty()) { - return TokenError( - "replicated shardings should not have any devices assigned"); + return Error(loc, + "replicated shardings should not have any devices assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError( - "replicated shardings should not have any tile shape set"); + return Error(loc, + "replicated shardings should not have any tile shape set"); } sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { - return TokenError( - "maximal shardings should have exactly one device assigned"); + return Error(loc, + "maximal shardings should have exactly one device assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("maximal shardings should not have any tile shape set"); + return Error(loc, "maximal shardings should not have any tile shape set"); } sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); sharding->add_tile_assignment_devices(devices[0]); } else { if (devices.size() <= 1) { - return TokenError( - "non-maximal shardings must have more than one device assigned"); + return Error( + loc, "non-maximal shardings must have more than one device assigned"); } if (ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("non-maximal shardings should have a tile shape set"); + return Error(loc, "non-maximal shardings should have a tile shape set"); } if (tile_assignment_dimensions.empty()) { - return TokenError( + return Error( + loc, "non-maximal shardings must have a tile assignment list including " "dimensions"); } @@ -1008,10 +1223,11 @@ bool HloParser::ParseInstructionNames( "expects '{' at the beginning of instruction name list")) { return false; } + LocTy loc = lexer_.GetLoc(); do { string name; if (!ParseName(&name)) { - return TokenError("expects a instruction name"); + return Error(loc, "expects a instruction name"); } HloInstruction* instr = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); @@ -1023,7 +1239,7 @@ bool HloParser::ParseInstructionNames( } while (EatIfPresent(TokKind::kComma)); return ParseToken(TokKind::kRbrace, - "expects '}' at the end of control instructions"); + "expects '}' at the end of instruction name list"); } bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, @@ -1058,6 +1274,8 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, switch (shape.element_type()) { case F16: return SetValueInLiteralHelper(value, linear_index, literal); + case BF16: + return SetValueInLiteralHelper(value, linear_index, literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -1096,7 +1314,8 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, (std::numeric_limits::infinity() == value || -std::numeric_limits::infinity() == value))) { // Skip range checking for non-finite value. - } else if (literal->shape().element_type() == F16) { + } else if (literal->shape().element_type() == F16 || + literal->shape().element_type() == BF16) { if (value > kF16max || value < -kF16max) { return TokenError(StrCat( "value ", value, " is out of range for literal's primitive type ", @@ -1112,7 +1331,7 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, PrimitiveType_Name(literal->shape().element_type()))); } - literal->GetMutableArraySlice().at(linear_index) = + literal->data().at(linear_index) = static_cast(value); return true; } @@ -1179,9 +1398,19 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, // non_tuple // ::= rank01 // ::= rank2345 -// rank2345 ::= shape nested_array +// rank2345 ::= shape sparse_or_nested_array bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, const Shape& shape) { + if (LayoutUtil::IsSparseArray(shape)) { + return ParseSparseLiteral(literal, shape); + } + + CHECK(LayoutUtil::IsDenseArray(shape)); + return ParseDenseLiteral(literal, shape); +} + +bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, + const Shape& shape) { const int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1282,26 +1511,28 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, } lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); int64 value; if (!ParseInt64(&value)) { - return TokenError(StrCat("expects integer for primitive type: ", + return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } if (!SetValueInLiteral(value, linear_index++, literal->get())) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); double value; if (!ParseDouble(&value)) { - return TokenError( - StrCat("expect floating point value for primitive type: ", - PrimitiveType_Name(shape.element_type()))); + return Error( + loc, StrCat("expect floating point value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); } if (!SetValueInLiteral(value, linear_index++, literal->get())) { return false; } } else { - return TokenError(StrCat("unsupported premitive type ", + return TokenError(StrCat("unsupported primitive type ", PrimitiveType_Name(shape.element_type()))); } break; @@ -1313,11 +1544,147 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, return true; } +bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, + const Shape& shape) { + if (!EatShapeAndCheckCompatible(shape)) { + return false; + } + + switch (shape.element_type()) { + case PRED: + return ParseSparseLiteralHelper(literal, shape); + case S8: + return ParseSparseLiteralHelper(literal, shape); + case S16: + return ParseSparseLiteralHelper(literal, shape); + case S32: + return ParseSparseLiteralHelper(literal, shape); + case S64: + return ParseSparseLiteralHelper(literal, shape); + case U8: + return ParseSparseLiteralHelper(literal, shape); + case U16: + return ParseSparseLiteralHelper(literal, shape); + case U32: + return ParseSparseLiteralHelper(literal, shape); + case U64: + return ParseSparseLiteralHelper(literal, shape); + case F16: + return ParseSparseLiteralHelper(literal, shape); + case F32: + return ParseSparseLiteralHelper(literal, shape); + case BF16: + return ParseSparseLiteralHelper(literal, shape); + case F64: + return ParseSparseLiteralHelper(literal, shape); + default: + return Error(lexer_.GetLoc(), + StrCat("invalid primitive type for sparse literal: ", + PrimitiveType_Name(shape.element_type()))); + } +} + +template +bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, + const Shape& shape) { + std::vector index; + + int64 rank = ShapeUtil::Rank(shape); + + *literal = MakeUnique(shape); + + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of a sparse literal")) { + return false; + } + + for (;;) { + if (lexer_.GetKind() == TokKind::kRbrace) { + lexer_.Lex(); + break; + } + + LocTy index_loc = lexer_.GetLoc(); + index.clear(); + if (lexer_.GetKind() == TokKind::kInt) { + int64 single_index = lexer_.GetInt64Val(); + lexer_.Lex(); + if (rank != 1) { + return Error( + index_loc, + StrCat("invalid single-dimensional index for shape with rank ", + rank, ": ", single_index)); + } + index.push_back(single_index); + } else { + if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, + &index)) { + return false; + } + if (index.size() != rank) { + return Error( + index_loc, + StrCat("invalid multi-dimension index for shape with rank ", rank, + ": [", tensorflow::str_util::Join(index, ", "), "]")); + } + } + if (!ParseToken(TokKind::kColon, + "expects ':' after after the sparse array index and before " + "the sparse array value")) { + return false; + } + LocTy value_loc = lexer_.GetLoc(); + LiteralNativeT value; + if (lexer_.GetKind() == TokKind::kw_true || + lexer_.GetKind() == TokKind::kw_false) { + value = static_cast(lexer_.GetKind() == TokKind::kw_true); + lexer_.Lex(); + } else if (primitive_util::IsIntegralType(shape.element_type())) { + int64 value_s64; + if (!ParseInt64(&value_s64)) { + return Error(value_loc, + StrCat("expects integer for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + value = static_cast(value_s64); + } else if (primitive_util::IsFloatingPointType(shape.element_type())) { + double value_f64; + if (!ParseDouble(&value_f64)) { + return Error(value_loc, + StrCat("expects floating point value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + value = static_cast(value_f64); + } else { + LOG(FATAL) << "Unexpected element type: " + << PrimitiveType_Name(shape.element_type()); + } + if (lexer_.GetKind() != TokKind::kRbrace && + !ParseToken(TokKind::kComma, + "expects ',' separator between sparse array elements")) { + return false; + } + + if ((*literal)->sparse_element_count() + 1 == + LayoutUtil::MaxSparseElements(shape.layout())) { + return Error( + lexer_.GetLoc(), + StrCat("number of sparse elements exceeds maximum for layout: ", + ShapeUtil::HumanStringWithLayout(shape))); + } + + (*literal)->AppendSparseElement(index, value); + } + + (*literal)->SortSparseElements(); + return true; +} + // operands ::= '(' operands1 ')' // operands1 // ::= /*empty*/ // ::= operand (, operand)* -// operand ::= shape name +// operand ::= (shape)? name bool HloParser::ParseOperands(std::vector* operands) { if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of operands")) { @@ -1327,15 +1694,21 @@ bool HloParser::ParseOperands(std::vector* operands) { // empty } else { do { - Shape shape; + LocTy loc = lexer_.GetLoc(); string name; - if (!ParseShape(&shape) || !ParseName(&name)) { + if (CanBeShape()) { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + } + if (!ParseName(&name)) { return false; } HloInstruction* instruction = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); if (!instruction) { - return TokenError(StrCat("instruction does not exist: ", name)); + return Error(loc, StrCat("instruction does not exist: ", name)); } operands->push_back(instruction); } while (EatIfPresent(TokKind::kComma)); @@ -1345,11 +1718,12 @@ bool HloParser::ParseOperands(std::vector* operands) { bool HloParser::ParseOperands(std::vector* operands, const int expected_size) { + LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { return false; } if (expected_size != operands->size()) { - return TokenError(StrCat("expects ", expected_size, " operands, but has ", + return Error(loc, StrCat("expects ", expected_size, " operands, but has ", operands->size(), " operands")); } return true; @@ -1358,6 +1732,7 @@ bool HloParser::ParseOperands(std::vector* operands, // sub_attributes ::= '{' (','? attribute)* '}' bool HloParser::ParseSubAttributes( const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) { return false; } @@ -1376,7 +1751,7 @@ bool HloParser::ParseSubAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return TokenError(Printf("sub-attribute %s is expected but not seen", + return Error(loc, Printf("sub-attribute %s is expected but not seen", attr_it.first.c_str())); } } @@ -1386,6 +1761,7 @@ bool HloParser::ParseSubAttributes( // attributes ::= (',' attribute)* bool HloParser::ParseAttributes( const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); std::unordered_set seen_attrs; while (EatIfPresent(TokKind::kComma)) { if (!ParseAttributeHelper(attrs, &seen_attrs)) { @@ -1396,7 +1772,7 @@ bool HloParser::ParseAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return TokenError(Printf("attribute %s is expected but not seen", + return Error(loc, Printf("attribute %s is expected but not seen", attr_it.first.c_str())); } } @@ -1406,21 +1782,23 @@ bool HloParser::ParseAttributes( bool HloParser::ParseAttributeHelper( const std::unordered_map& attrs, std::unordered_set* seen_attrs) { + LocTy loc = lexer_.GetLoc(); string name; if (!ParseAttributeName(&name)) { - return TokenError("error parsing attributes"); + return Error(loc, "error parsing attributes"); } VLOG(1) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { - return TokenError(Printf("attribute %s already exists", name.c_str())); + return Error(loc, Printf("attribute %s already exists", name.c_str())); } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return TokenError(Printf("unexpected attribute %s", name.c_str())); + return Error(loc, Printf("unexpected attribute %s", name.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; bool success = [&] { + LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { case AttrTy::kInt64: { int64 result; @@ -1436,7 +1814,7 @@ bool HloParser::ParseAttributeHelper( return false; } if (result != static_cast(result)) { - return TokenError("value out of range for int32"); + return Error(attr_loc, "value out of range for int32"); } static_cast*>(attr_out_ptr) ->emplace(static_cast(result)); @@ -1449,7 +1827,7 @@ bool HloParser::ParseAttributeHelper( } if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest()) { - return TokenError("value out of range for float"); + return Error(attr_loc, "value out of range for float"); } static_cast*>(attr_out_ptr) ->emplace(static_cast(result)); @@ -1463,6 +1841,14 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kFftType: { + FftType result; + if (!ParseFftType(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } case AttrTy::kWindow: { Window result; if (!ParseWindow(&result)) { @@ -1548,22 +1934,32 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kDistribution: { + RandomDistribution result; + if (!ParseRandomDistribution(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { - return TokenError(Printf("error parsing attribute %s", name.c_str())); + return Error(loc, Printf("error parsing attribute %s", name.c_str())); } return true; } bool HloParser::ParseComputationName(HloComputation** value) { string name; + LocTy loc = lexer_.GetLoc(); if (!ParseName(&name)) { - return TokenError("expects computation name"); + return Error(loc, "expects computation name"); } *value = tensorflow::gtl::FindPtrOrNull(computation_pool_, name); if (*value == nullptr) { - return TokenError(StrCat("computation does not exist: ", name)); + return Error(loc, StrCat("computation does not exist: ", name)); } return true; } @@ -1572,6 +1968,7 @@ bool HloParser::ParseComputationName(HloComputation** value) { // The subattributes can appear in any order. 'size=' is required, others are // optional. bool HloParser::ParseWindow(Window* window) { + LocTy loc = lexer_.GetLoc(); if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -1581,10 +1978,12 @@ bool HloParser::ParseWindow(Window* window) { std::vector> pad; std::vector lhs_dilate; std::vector rhs_dilate; + std::vector rhs_reversal; while (lexer_.GetKind() != TokKind::kRbrace) { + LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { - return TokenError("expects sub-attributes in window"); + return Error(attr_loc, "expects sub-attributes in window"); } bool ok = [&] { if (field_name == "size") { @@ -1602,7 +2001,10 @@ bool HloParser::ParseWindow(Window* window) { if (field_name == "pad") { return ParseWindowPad(&pad); } - return TokenError(StrCat("unexpected attribute name: ", field_name)); + if (field_name == "rhs_reversal") { + return ParseDxD("rhs_reversal", &rhs_reversal); + } + return Error(attr_loc, StrCat("unexpected attribute name: ", field_name)); }(); if (!ok) { return false; @@ -1610,20 +2012,20 @@ bool HloParser::ParseWindow(Window* window) { } if (size.empty()) { - return TokenError( - "sub-attribute 'size=' is required in the window attribute"); + return Error(loc, + "sub-attribute 'size=' is required in the window attribute"); } if (!stride.empty() && stride.size() != size.size()) { - return TokenError("expects 'stride=' has the same size as 'size='"); + return Error(loc, "expects 'stride=' has the same size as 'size='"); } if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) { - return TokenError("expects 'lhs_dilate=' has the same size as 'size='"); + return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='"); } if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) { - return TokenError("expects 'rhs_dilate=' has the same size as 'size='"); + return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='"); } if (!pad.empty() && pad.size() != size.size()) { - return TokenError("expects 'pad=' has the same size as 'size='"); + return Error(loc, "expects 'pad=' has the same size as 'size='"); } for (int i = 0; i < size.size(); i++) { @@ -1638,6 +2040,8 @@ bool HloParser::ParseWindow(Window* window) { lhs_dilate.empty() ? 1 : lhs_dilate[i]); window->mutable_dimensions(i)->set_window_dilation( rhs_dilate.empty() ? 1 : rhs_dilate[i]); + window->mutable_dimensions(i)->set_window_reversal( + rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); } return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } @@ -1783,20 +2187,19 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); } do { + LocTy loc = lexer_.GetLoc(); ranges.emplace_back(); if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon, &ranges.back())) { return false; } - } while (EatIfPresent(TokKind::kComma)); - - for (const auto& range : ranges) { + const auto& range = ranges.back(); if (range.size() != 2 && range.size() != 3) { - return TokenError(Printf( - "expects [start:limit:step] or [start:limit], but sees %ld elements.", - range.size())); + return Error(loc, Printf("expects [start:limit:step] or [start:limit], " + "but sees %ld elements.", + range.size())); } - } + } while (EatIfPresent(TokKind::kComma)); for (const auto& range : ranges) { result->starts.push_back(range[0]); @@ -1832,6 +2235,19 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +// param_list_to_shape ::= param_list '->' shape +bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { + if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { + return false; + } + *shape_loc = lexer_.GetLoc(); + return ParseShape(shape); +} + +bool HloParser::CanBeParamListToShape() { + return lexer_.GetKind() == TokKind::kLparen; +} + // param_list ::= '(' param_list1 ')' // param_list1 // ::= /*empty*/ @@ -1848,8 +2264,8 @@ bool HloParser::ParseParamList() { } else { do { Shape shape; - if (!ParseToken(TokKind::kName, "expects name in parameter") || - !ParseShape(&shape)) { + string name; + if (!ParseName(&name) || !ParseShape(&shape)) { return false; } } while (EatIfPresent(TokKind::kComma)); @@ -1888,9 +2304,17 @@ bool HloParser::ParseShape(Shape* result) { return true; } +bool HloParser::CanBeShape() { + // A non-tuple shape starts with a kShape token; a tuple shape starts with + // '('. + return lexer_.GetKind() == TokKind::kShape || + lexer_.GetKind() == TokKind::kLparen; +} + bool HloParser::ParseName(string* result) { VLOG(1) << "ParseName"; - if (lexer_.GetKind() != TokKind::kName) { + if (lexer_.GetKind() != TokKind::kIdent && + lexer_.GetKind() != TokKind::kName) { return TokenError("expects name"); } *result = lexer_.GetStrVal(); @@ -1918,15 +2342,16 @@ bool HloParser::ParseString(string* result) { } bool HloParser::ParseDxD(const string& name, std::vector* result) { + LocTy loc = lexer_.GetLoc(); if (!result->empty()) { - return TokenError( - Printf("sub-attribute '%s=' already exists", name.c_str())); + return Error(loc, + Printf("sub-attribute '%s=' already exists", name.c_str())); } // 1D if (lexer_.GetKind() == TokKind::kInt) { int64 number; if (!ParseInt64(&number)) { - return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str())); + return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); } result->push_back(number); return true; @@ -1935,8 +2360,8 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { if (lexer_.GetKind() == TokKind::kDxD) { string str = lexer_.GetStrVal(); if (!SplitAndParseAsInts(str, 'x', result)) { - return TokenError( - Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + return Error(loc, + Printf("expects sub-attribute '%s=ixj...'", name.c_str())); } lexer_.Lex(); return true; @@ -1945,8 +2370,9 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } bool HloParser::ParseWindowPad(std::vector>* pad) { + LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { - return TokenError("sub-attribute 'pad=' already exists"); + return Error(loc, "sub-attribute 'pad=' already exists"); } if (lexer_.GetKind() != TokKind::kPad) { return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); @@ -1957,8 +2383,8 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { std::vector low_high; if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || low_high.size() != 2) { - return TokenError( - "expects padding_low and padding_high separated by '_'"); + return Error(loc, + "expects padding_low and padding_high separated by '_'"); } pad->push_back(low_high); } @@ -1974,15 +2400,16 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { if (lexer_.GetKind() != TokKind::kPad) { return TokenError("expects padding config, e.g., '0_0_0x3_3_1'"); } + LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (const auto& padding_dim_str : padding_str) { std::vector padding_dim; if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { - return TokenError( - "expects padding config pattern like 'low_high_interior' or " - "'low_high'"); + return Error(loc, + "expects padding config pattern like 'low_high_interior' or " + "'low_high'"); } auto* dim = padding->add_dimensions(); dim->set_edge_padding_low(padding_dim[0]); @@ -2024,20 +2451,64 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { bool HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; - if (lexer_.GetKind() != TokKind::kOpcode) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects opcode"); } - *result = lexer_.GetOpcodeVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToHloOpcode(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects opcode but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseFftType(FftType* result) { + VLOG(1) << "ParseFftType"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects fft type"); + } + string val = lexer_.GetStrVal(); + if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { + return TokenError(Printf("expects fft type but sees: %s", val.c_str())); + } lexer_.Lex(); return true; } bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(1) << "ParseFusionKind"; - if (lexer_.GetKind() != TokKind::kFusionKind) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects fusion kind"); } - *result = lexer_.GetFusionKindVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToFusionKind(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseRandomDistribution(RandomDistribution* result) { + VLOG(1) << "ParseRandomDistribution"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects random distribution"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToRandomDistribution(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects random distribution but sees: %s, error: %s", + val.c_str(), status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); lexer_.Lex(); return true; } @@ -2103,20 +2574,20 @@ bool HloParser::EatIfPresent(TokKind kind) { return true; } -bool HloParser::AddInstruction(const string& name, - HloInstruction* instruction) { +bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc) { auto result = instruction_pool_.insert({name, instruction}); if (!result.second) { - return TokenError(StrCat("instruction already exists: ", name)); + return Error(name_loc, StrCat("instruction already exists: ", name)); } return true; } -bool HloParser::AddComputation(const string& name, - HloComputation* computation) { +bool HloParser::AddComputation(const string& name, HloComputation* computation, + LocTy name_loc) { auto result = computation_pool_.insert({name, computation}); if (!result.second) { - return TokenError(StrCat("computation already exists: ", name)); + return Error(name_loc, StrCat("computation already exists: ", name)); } return true; } @@ -2127,7 +2598,7 @@ StatusOr> Parse(StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { - return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); } return parser.ConsumeHloModule(); } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 90cdb87a1ebcf59d291eebd52963a130f19f4403..dd76d8d0fee7cdfa22829fe92ff889e44157216e 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -25,7 +25,6 @@ namespace tools { namespace { using tensorflow::StringPiece; -using tensorflow::strings::StrCat; struct TestData { string test_name; @@ -46,7 +45,7 @@ std::vector CreateTestCases() { // ax + y { "AxpyParam", -R"(HloModule axpy_module: +R"(HloModule axpy_module ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { %alpha = f32[] parameter(0) @@ -62,7 +61,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { // pred constant { "ConstantPred", -R"(HloModule constant_pred_module: +R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} @@ -73,7 +72,7 @@ ENTRY %constant_pred () -> pred[] { // s32 constant { "ConstantS32", -R"(HloModule constant_s32_module: +R"(HloModule constant_s32_module ENTRY %constant_s32 () -> s32[] { ROOT %constant = s32[] constant(-42) @@ -84,7 +83,7 @@ ENTRY %constant_s32 () -> s32[] { // f32 constant, but the value is not a decimal { "ConstantF32", -R"(HloModule ConstantF32_module: +R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { ROOT %constant = f32[] constant(42) @@ -95,7 +94,7 @@ ENTRY %ConstantF32.v4 () -> f32[] { // f32 constant, rank 1 empty array. { "ConstantF32R1Empty", -R"(HloModule ConstantF32Empty_module: +R"(HloModule ConstantF32Empty_module ENTRY %ConstantF32Empty.v4 () -> f32[0] { ROOT %constant = f32[0]{0} constant({}) @@ -106,7 +105,7 @@ ENTRY %ConstantF32Empty.v4 () -> f32[0] { // f32 constant, rank 4 empty array. { "ConstantF32R4Empty", -R"(HloModule ConstantF32R4Empty_module: +R"(HloModule ConstantF32R4Empty_module ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) @@ -117,7 +116,7 @@ ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { // constant 4D { "Constant4D", -R"(HloModule Small_3x2x1x1_module: +R"(HloModule Small_3x2x1x1_module ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) @@ -128,7 +127,7 @@ ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { // non-finite constants: nan, inf, -inf { "ConstantNonFinite", -R"(HloModule IsFiniteR1F32s_module: +R"(HloModule IsFiniteR1F32s_module ENTRY %IsFiniteR1F32s.v2 () -> pred[6] { %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf}) @@ -140,18 +139,29 @@ ENTRY %IsFiniteR1F32s.v2 () -> pred[6] { // constant f16 { "ConstantF16", -R"(HloModule ConstantF16_module: +R"(HloModule ConstantF16_module ENTRY %ConstantF16.v4 () -> f16[] { ROOT %constant = f16[] constant(500) } +)" +}, +// bf16 +{ +"BF16", +R"(HloModule BF16 + +ENTRY %BF16.v4 () -> bf16[] { + ROOT %constant = bf16[] constant(500) +} + )" }, // constant + constant { "AddConstants", -R"(HloModule add_constants_module: +R"(HloModule add_constants_module ENTRY %add_constants () -> f32[] { %constant = f32[] constant(3.14) @@ -163,7 +173,7 @@ ENTRY %add_constants () -> f32[] { // tuple constant { "TupleConstant", -R"(HloModule TupleConstant_module: +R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) @@ -174,7 +184,7 @@ ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { // v1 > v2 ? v1 : v2 { "SelectR1F32", -R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module: +R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} @@ -188,7 +198,7 @@ ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f3 // empty tuple { "EmptyTupleCreate", -R"(HloModule EmptyTupleCreate_module: +R"(HloModule EmptyTupleCreate_module ENTRY %EmptyTupleCreate.v1 () -> () { ROOT %tuple = () tuple() @@ -199,7 +209,7 @@ ENTRY %EmptyTupleCreate.v1 () -> () { // tuple { "TupleCreate", -R"(HloModule TupleCreate_module: +R"(HloModule TupleCreate_module ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { %v1 = f32[] parameter(0) @@ -212,7 +222,7 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f }, { "ShardedTupleCreate", -R"(HloModule ShardedTupleCreate_module: +R"(HloModule ShardedTupleCreate_module ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { %v1 = f32[] parameter(0) @@ -227,7 +237,7 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 // while (result < 5) { result = result + 1; } { "WhileWithScalarS32Result", -R"(HloModule WhileWithScalarS32Result_module: +R"(HloModule WhileWithScalarS32Result_module %body.v3 (prev.1: s32[]) -> s32[] { %constant = s32[] constant(1) @@ -251,7 +261,7 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { // send and recv { "SendRecv", -R"(HloModule TwoSendRecvBothWayRecvFist_module: +R"(HloModule TwoSendRecvBothWayRecvFist_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} @@ -266,7 +276,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { // get-tuple-element { "GetTupleElement", -R"(HloModule GetTupleElement_module: +R"(HloModule GetTupleElement_module ENTRY %GetTupleElement.v4 () -> s32[2,3] { %constant = f32[3]{0} constant({1, 2, 3}) @@ -280,7 +290,7 @@ ENTRY %GetTupleElement.v4 () -> s32[2,3] { // call { "Call", -R"(HloModule CallR0F32IdentityScalar_module: +R"(HloModule CallR0F32IdentityScalar_module %Identity.v1 (x: f32[]) -> f32[] { ROOT %x = f32[] parameter(0) @@ -296,7 +306,7 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { // reduce window { "ReduceWindow", -R"(HloModule R4UnitWindow_module: +R"(HloModule R4UnitWindow_module %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { %lhs = f32[] parameter(0) @@ -315,7 +325,7 @@ ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] { // reduce window on scalar { "ReduceWindowScalar", -R"(HloModule reduce_window_scalar: +R"(HloModule reduce_window_scalar %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { %lhs = f32[] parameter(0) @@ -334,7 +344,7 @@ ENTRY %R4UnitWindowScalar () -> f32[] { // convolution { "Convolution", -R"(HloModule Convolve1D1Window_0_module: +R"(HloModule Convolve1D1Window_0_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -348,7 +358,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 // convolution rank 2 { "ConvolutionR2", -R"(HloModule ConvolveR2_module: +R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) @@ -356,12 +366,25 @@ ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf } +)" +}, +// convolution backward +{ +"ConvolutionBackward", +R"(HloModule ConvolveBackward_module + +ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { + %input = f32[128,7,7,512]{0,3,2,1} parameter(0) + %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f +} + )" }, // reverse(constant) { "Reverse4D", -R"(HloModule Reverse4DFloatArrayOnDim01_module: +R"(HloModule Reverse4DFloatArrayOnDim01_module ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) @@ -373,7 +396,7 @@ ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { // concat { "Concat", -R"(HloModule Concat2x3With2x5_module: +R"(HloModule Concat2x3With2x5_module ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) @@ -381,50 +404,12 @@ ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} } -)" -}, -// map -{ -"Map", -R"(HloModule MapBinaryAdder_module: - -%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { - %lhs = f32[] parameter(0) - %rhs = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) -} - -ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] { - %param0 = f32[4]{0} parameter(0) - %param1 = f32[4]{0} parameter(1) - ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3 -} - -)" -}, -// reduce -{ -"Reduce", -R"(HloModule ReduceR3ToR2_module: - -%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { - %lhs = f32[] parameter(0) - %rhs = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) -} - -ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] { - %input = f32[8,16,256]{2,1,0} parameter(0) - %constant = f32[] constant(0) - ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3 -} - )" }, // select and scatter { "SelectAndScatter", -R"(HloModule R4F32OverlapSmall_module: +R"(HloModule R4F32OverlapSmall_module %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) @@ -450,7 +435,7 @@ ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { // select and scatter on scalar { "SelectAndScatterScalar", -R"(HloModule select_and_scatter_scalar: +R"(HloModule select_and_scatter_scalar %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) @@ -476,7 +461,7 @@ ENTRY %SelectAndScatterScalar () -> f32[] { // slice { "Slice", -R"(HloModule slice_module: +R"(HloModule slice_module ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) @@ -488,7 +473,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { // slice, no stride { "SliceNoStride", -R"(HloModule Slice3x3x3_To_1x3x3_F32_module: +R"(HloModule Slice3x3x3_To_1x3x3_F32_module ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) @@ -500,7 +485,7 @@ ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { // slice R0 { "SliceR0", -R"(HloModule SliceR0_module: +R"(HloModule SliceR0_module ENTRY %SliceR0.v2 () -> s32[] { %constant = s32[] constant(1) @@ -512,7 +497,7 @@ ENTRY %SliceR0.v2 () -> s32[] { // transpose { "Transpose", -R"(HloModule Transpose_module: +R"(HloModule Transpose_module ENTRY %Transpose.v2 () -> s32[1,2,3] { %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) @@ -524,7 +509,7 @@ ENTRY %Transpose.v2 () -> s32[1,2,3] { // Dynamic slice { "DynamicSlice", -R"(HloModule DynamicSlice_module: +R"(HloModule DynamicSlice_module ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] { %original_parameter = s32[2,2,258]{2,1,0} parameter(0) @@ -539,7 +524,7 @@ ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) - // Dynamic update slice { "DynamicUpdateSlice", -R"(HloModule DynamicUpdateSlice_module: +R"(HloModule DynamicUpdateSlice_module ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { %input = s32[1,1,25,1]{3,2,1,0} parameter(0) @@ -553,7 +538,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ // batch norm training { "BatchNormTraining", -R"(HloModule BasicTraining_module: +R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) @@ -567,7 +552,7 @@ ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { // batch norm inference { "BatchNormInference", -R"(HloModule BatchNormInference_module: +R"(HloModule BatchNormInference_module ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] { %input = f32[2,2,2,2]{3,2,1,0} parameter(0) @@ -583,7 +568,7 @@ ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2] // batch norm grad { "BatchNormGrad", -R"(HloModule BatchNormGrad_module: +R"(HloModule BatchNormGrad_module ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) { %input = f32[2,2,2,2]{3,2,1,0} parameter(0) @@ -594,12 +579,60 @@ ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], varia ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0 } +)" +}, +// fft +{ +"Fft", +R"(HloModule Fft_module + +ENTRY %Fft (input: c64[8,32]) -> c64[8,32] { + %input = c64[8,32]{1,0} parameter(0) + ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32} +} + +)" +}, +// ifft +{ +"Ifft2d", +R"(HloModule Ifft2d_module + +ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] { + %input = c64[5,8,32]{2,1,0} parameter(0) + ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32} +} + +)" +}, +// rfft2d +{ +"Rfft2d", +R"(HloModule Rfft2d_module + +ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] { + %input = f32[5,64,32]{2,1,0} parameter(0) + ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32} +} + +)" +}, +// irfft3d +{ +"Irfft3d", +R"(HloModule Irfft3d_module + +ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] { + %input = c64[5,64,128,33]{3,2,1,0} parameter(0) + ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64} +} + )" }, // pad { "Pad", -R"(HloModule Pad1DS3Array_module: +R"(HloModule Pad1DS3Array_module ENTRY %Pad1DS3Array.v3 () -> f32[8] { %constant = f32[3]{0} constant({1, 2, 3}) @@ -612,7 +645,7 @@ ENTRY %Pad1DS3Array.v3 () -> f32[8] { // pad has interior { "PadHasInterior", -R"(HloModule PadHasInterior_module: +R"(HloModule PadHasInterior_module ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { %input = f32[1,25,7,7]{3,2,1,0} parameter(0) @@ -620,12 +653,25 @@ ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0 } +)" +}, +// Negative padding +{ +"PadHasNegativePadding", +R"(HloModule PadHasNegativePadding_module + +ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] { + %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0) + %constant = f32[] constant(-5.123) + ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3 +} + )" }, // fusion { "Fusion", -R"(HloModule fusion_module: +R"(HloModule fusion_module %fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] { %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) @@ -640,22 +686,182 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation } +)" +}, +{ +"Sparse", +R"(HloModule sparse_f32 + +ENTRY %sparse () -> f32[2,3,4] { + ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) +} + +)" +}, +{ +"SparseEmpty", +R"(HloModule sparse_f32_empty + +ENTRY %sparse_f32_empty () -> f32[2,3,4] { + ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{}) +} + +)" +}, +{ +"SparseR1", +R"(HloModule sparse_f32_r1 + +ENTRY %sparse_f32_r1 () -> f32[9] { + ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) +} + +)" +}, + }); + // clang-format on +} + +std::vector CreateShortTestCases() { + // clang-format off + return std::vector({ +// map +{ +"Map", +R"(HloModule MapBinaryAdder_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY MapBinaryAdder.v3 { + param0 = f32[4]{0} parameter(0) + param1 = f32[4]{0} parameter(1) + ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 +} + +)" +}, +// reduce +{ +"Reduce", +R"(HloModule ReduceR3ToR2_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY ReduceR3ToR2.v3 { + input = f32[8,16,256]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +} + )" }, // infeed/outfeed { "InfeedOutfeed", -R"(HloModule outfeed_module: +R"(HloModule outfeed_module + +ENTRY InfeedToOutfeed { + infeed = (u32[3]{0}, pred[]) infeed() + outfeed = () outfeed(infeed) + ROOT infeed.1 = (u32[3]{0}, pred[]) infeed() + outfeed.1 = () outfeed(infeed.1) +} + +)" +}, +// Rng +{ +"Rng", +R"(HloModule rng_module -ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { - %infeed = (u32[3]{0}, pred[]) infeed() - %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed) - ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed() - %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1) +ENTRY Rng { + constant = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform } )" +}, +// Reduce precision +{ +"ReducePrevison", +R"(HloModule reduce_precision + +ENTRY ReducePrecision { + constant = f32[1]{0} constant({3.14159}) + ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10 } + +)" +}, +// Conditional +{ +"Conditional", +R"(HloModule conditional + +Negate { + x = f32[] parameter(0) + ROOT negate = f32[] negate(x) +} + +Identity { + y = f32[] parameter(0) + ROOT copy = f32[] copy(y) +} + +ENTRY Parameters1.v4 { + constant = pred[] constant(true) + constant.1 = f32[] constant(56) + constant.2 = f32[] constant(12) + ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity +} + +)" +}, +// CustomCall +{ +"CustomCall", +R"(HloModule custom_call + +ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar" +} + +)" +}, +// Variables with non-default names +{ +"NonDefaultNames", +R"(HloModule add_constants_module + +ENTRY add_constants { + foo = f32[] constant(3.14) + ROOT bar = f32[] add(foo, foo) +} + +)" +}, +{ +"Dot", +R"(HloModule dot + +ENTRY dot { + a = f32[2,10]{1,0} parameter(0) + b = f32[10,3]{1,0} parameter(1) + ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)" +}, }); // clang-format on } @@ -674,18 +880,35 @@ class HloParserTest : public ::testing::Test, void ExpectEqual() { const string& original = GetParam().module_string; auto result = Parse(original); - TF_EXPECT_OK(result.status()); + TF_ASSERT_OK(result.status()); + EXPECT_EQ(original, result.ValueOrDie()->ToString( + HloPrintOptions().set_print_large_constants(true))); + } +}; + +class HloParserShortTest : public HloParserTest { + protected: + void ExpectEqualShort() { + const string& original = GetParam().module_string; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); EXPECT_EQ(original, - result.ValueOrDie()->ToString(/*include_large_constants=*/true)); + result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); } }; TEST_P(HloParserTest, Run) { ExpectEqual(); } +TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); } + INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, ::testing::ValuesIn(CreateTestCases()), TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); + TEST_F(HloParserTest, Empty) { const string original = ""; auto result = Parse(original); @@ -749,7 +972,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { } TEST_F(HloParserTest, MoreConstants) { - const string original = R"(HloModule SelectScalarS32True_module: + const string original = R"(HloModule SelectScalarS32True_module ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) @@ -766,7 +989,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { } TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { - const string original = R"(HloModule some_2_module: + const string original = R"(HloModule some_2_module ENTRY %some_2 () -> f32[2] { ROOT %constant = f32[2]{0} constant({1,{2}}) @@ -780,7 +1003,7 @@ ENTRY %some_2 () -> f32[2] { } TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { - const string original = R"(HloModule some_2x3_module: + const string original = R"(HloModule some_2x3_module ENTRY %some_2x3 () -> f32[2,3] { ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) @@ -794,7 +1017,7 @@ ENTRY %some_2x3 () -> f32[2,3] { } TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { - const string original = R"(HloModule some_2x3x2_module: + const string original = R"(HloModule some_2x3x2_module ENTRY %some_2x3x2 () -> f32[2,3,2] { ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) @@ -809,7 +1032,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { TEST_F(HloParserTest, ConstantF16Overflow) { const string original = - R"(HloModule ConstantF16Overflow_module: + R"(HloModule ConstantF16Overflow_module ENTRY %ConstantF16Overflow.v4 () -> f16[] { ROOT %constant = f16[] constant(-65505) @@ -823,7 +1046,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { } TEST_F(HloParserTest, ConstantWithExp) { - const string original = R"(HloModule ConstantWithExp_module: + const string original = R"(HloModule ConstantWithExp_module ENTRY %ConstantWithExp.v4 () -> f32[] { %constant.1 = f32[] constant(3e+2) @@ -838,7 +1061,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { } TEST_F(HloParserTest, AttibutesAnyOrder) { - const string original = R"(HloModule any_order_module: + const string original = R"(HloModule any_order_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -852,7 +1075,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } TEST_F(HloParserTest, InvalidDimLabels) { - string prefix = R"(HloModule invalid_dim_labels_module: + string prefix = R"(HloModule invalid_dim_labels_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -864,19 +1087,21 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; - ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix)) - .status() - .error_message(), - "expects dim labels pattern"); + ExpectHasSubstr( + Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); - ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix)) + ExpectHasSubstr(Parse(tensorflow::strings::StrCat( + prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), "must have the same rank"); } TEST_F(HloParserTest, UnexpectedAttribute) { - const string original = R"(HloModule unexpected_attr_module: + const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15 @@ -892,7 +1117,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } TEST_F(HloParserTest, MissingAttribute) { - const string original = R"(HloModule missing_attr_module: + const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15 @@ -908,7 +1133,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } TEST_F(HloParserTest, PredecessorUndefined) { - const string original = R"(HloModule pre_not_found_module: + const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15 @@ -924,7 +1149,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } TEST_F(HloParserTest, SliceAllowOmitStride1) { - const string original = R"(HloModule slice_module: + const string original = R"(HloModule slice_module ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) @@ -936,7 +1161,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { } TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { - const string original = R"(HloModule window_pad_module: + const string original = R"(HloModule window_pad_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -951,7 +1176,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } TEST_F(HloParserTest, CommaBetweenSubAttributes) { - const string original = R"(HloModule test_comma_module: + const string original = R"(HloModule test_comma_module ENTRY %test_comma.v4 () -> f32[] { ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"} @@ -961,6 +1186,95 @@ ENTRY %test_comma.v4 () -> f32[] { TF_EXPECT_OK(Parse(original).status()); } +TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { + const string original = R"(HloModule custom_call: + +ENTRY %CustomCall () -> f32[1] { + %constant = f32[1]{0} constant({12345}) + ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + "Shape of computation CustomCall, f32[1], is not compatible " + "with that of its root instruction foo, f32[1,2,3]"); +} + +TEST_F(HloParserTest, EntryComputationWithLayout) { + const string original = R"(HloModule layout: +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { + input = f32[8,16,256]{0,1,2} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +})"; + + auto module = Parse(original); + TF_ASSERT_OK(module.status()); + auto program_layout = module.ValueOrDie()->entry_computation_layout(); + ASSERT_EQ(program_layout.parameter_count(), 1); + auto param_layout = program_layout.parameter_layout(0).layout(); + auto result_layout = program_layout.result_layout().layout(); + EXPECT_TRUE( + LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout)) + << "actual layout of parameter(0) is " + << LayoutUtil::HumanString(param_layout); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout)) + << "actual layout of result is " + << LayoutUtil::HumanString(result_layout); +} + +TEST_F(HloParserTest, NoEntry) { + const string original = R"(HloModule no_entry: +c1 { + const1 = f32[1]{0} constant({12345}) +} +c2 { + const2 = f32[1]{0} constant({67890}) +})"; + auto module = Parse(original); + TF_ASSERT_OK(module.status()); + EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); +} + +TEST_F(HloParserTest, NoRoot) { + const string original = R"(HloModule no_root: +ENTRY consts { + first = f32[1]{0} constant({12345}) + last = f32[1]{0} constant({67890}) +})"; + auto module = Parse(original); + TF_ASSERT_OK(module.status()); + EXPECT_EQ( + module.ValueOrDie()->entry_computation()->root_instruction()->name(), + "last"); +} + +TEST_F(HloParserTest, MultipleEntries) { + const string original = R"(HloModule multiple_entries: +ENTRY c1 { + const1 = f32[1]{0} constant({12345}) +} +ENTRY c2 { + const2 = f32[1]{0} constant({67890}) +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + "expects only one ENTRY"); +} + +TEST_F(HloParserTest, MultipleRoots) { + const string original = R"(HloModule multiple_roots: +ENTRY consts { + ROOT const1 = f32[1]{0} constant({12345}) + ROOT const2 = f32[1]{0} constant({12345}) +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + "one computation should have only one ROOT"); +} + } // namespace } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h index 07e48804d053f31bdff6678f09ee2c1e3b731e0f..7928bee5c2097f353b182095a555c334d7b69c95 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -18,6 +18,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" + namespace xla { namespace tools { @@ -60,10 +63,9 @@ enum class TokKind { kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} kDxD, // [0-9]+(x[0-9]+)+ kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers kString, // "abcd\"\n" kShape, // f32[2,3]{1,0} - kOpcode, // add - kFusionKind, // kLoop, kOutput, ... kInt, // 42 kDecimal, // 4.2 }; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ec3f6a0471e2ae965846f5ef7560e448fe9d8073..eda5effbb92db92c9317a956497a00c0ec15c27c 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -59,25 +59,33 @@ namespace xla { namespace tools { namespace { +// Command-line opts to this tool. See main() for descriptions of these +// fields. +struct Options { + string fake_infeed_shape; + bool use_fake_data = false; + bool print_result = true; + int num_runs = 1; +}; + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; // otherwise, no infeed is performed. StatusOr> ReplayComputation( - const SessionModule& module, int num_runs, - tensorflow::StringPiece fake_infeed_shape, bool use_fake_data, - Client* client) { + const SessionModule& module, Client* client, const Options& opts) { TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); std::vector> arguments; - if (use_fake_data) { + if (opts.use_fake_data) { arguments = MakeFakeArgumentsOrDie(computation, client); } else { // use recorded data if available for (const auto& proto : module.arguments()) { - Literal literal(proto); + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + Literal::CreateFromProto(proto)); TF_ASSIGN_OR_RETURN(std::unique_ptr data, - client->TransferToServer(literal)); + client->TransferToServer(*literal)); arguments.push_back(std::move(data)); } } @@ -86,12 +94,12 @@ StatusOr> ReplayComputation( // concurrent infeed occur via the fake_infeed_shape. tensorflow::gtl::optional pool; - if (!fake_infeed_shape.empty()) { + if (!opts.fake_infeed_shape.empty()) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([fake_infeed_shape, client]() { + pool->Schedule([opts, client]() { StatusOr shape_status = - ShapeUtil::ParseShapeString(fake_infeed_shape); + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); TF_CHECK_OK(shape_status.status()); Shape shape = std::move(shape_status).ValueOrDie(); StatusOr> data_status = MakeFakeLiteral(shape); @@ -112,19 +120,19 @@ StatusOr> ReplayComputation( // Run the computation num_runs times, and return the result from the last // execution. std::unique_ptr result; - for (int i = 0; i < num_runs; ++i) { + for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; - if (use_fake_data) { - // If using fake data, execute the computation but don't bother retrieving - // the result -- presumably it's uninteresting, since our data is fake. + if (opts.print_result) { + TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( + computation, execute_arguments, + /*execution_options=*/nullptr, &profile)); + } else { + // If we're not printing the result, execute the computation but don't + // bother retrieving the result. This can be a significant speedup. TF_RETURN_IF_ERROR(client ->Execute(computation, execute_arguments, /*execution_options=*/nullptr, &profile) .status()); - } else { - TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( - computation, execute_arguments, - /*execution_options=*/nullptr, &profile)); } LOG(INFO) << "Execution took " << static_cast(profile.compute_time_ns()) / 1e9 << "s"; @@ -133,16 +141,15 @@ StatusOr> ReplayComputation( return std::move(result); } -int RealMain(tensorflow::gtl::ArraySlice args, int num_runs, - tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) { +int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { Client* client = ClientLibrary::LocalClientOrDie(); tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { SessionModule module; TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); - StatusOr> result_status = ReplayComputation( - module, num_runs, fake_infeed_shape, use_fake_data, client); + StatusOr> result_status = + ReplayComputation(module, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -156,12 +163,16 @@ int RealMain(tensorflow::gtl::ArraySlice args, int num_runs, ShapeUtil::HumanString(result->shape()).c_str(), result->ToString().c_str()); if (module.has_result()) { + std::unique_ptr literal = + Literal::CreateFromProto(module.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(module.result().shape()).c_str(), - Literal(module.result()).ToString().c_str()); + literal->ToString().c_str()); } } } + + ClientLibrary::DestroyLocalInstances(); return exit_status; } @@ -170,16 +181,15 @@ int RealMain(tensorflow::gtl::ArraySlice args, int num_runs, } // namespace xla int main(int argc, char** argv) { - // Flags - xla::string fake_infeed_shape; - bool use_fake_data = false; - int num_runs = 1; + xla::tools::Options opts; const std::vector flag_list = { - tensorflow::Flag("use_fake_data", &use_fake_data, + tensorflow::Flag("use_fake_data", &opts.use_fake_data, "Replay computation using fake data"), - tensorflow::Flag("num_runs", &num_runs, + tensorflow::Flag("print_result", &opts.print_result, + "Print the result of the computation to stdout"), + tensorflow::Flag("num_runs", &opts.num_runs, "Number of times to run each computation"), - tensorflow::Flag("fake_infeed_shape", &fake_infeed_shape, + tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); @@ -191,5 +201,5 @@ int main(int argc, char** argv) { tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - return xla::tools::RealMain(args, num_runs, fake_infeed_shape, use_fake_data); + return xla::tools::RealMain(args, opts); } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index b50cb5e28eac14ed99af566939f8bd64e393ff64..fe8e72ba32bb4493b2751cfdfeb977f271092f9c 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -40,7 +40,8 @@ int main(int argc, char **argv) { xla::LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], &literal_proto)); - xla::Literal literal(literal_proto); + std::unique_ptr literal = + xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", literal.ToString().c_str()); + fprintf(stderr, "%s\n", literal->ToString().c_str()); } diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index bbe9902aa17a585c4bad5b732330305dfdd45302..8525873e913185554d18df8c8c3584bfcdcdcabe 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -39,13 +39,13 @@ int main(int argc, char **argv) { std::unique_ptr literal = xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); - LOG(INFO) << "literal: " << literal->ShortDebugString(); + LOG(INFO) << "literal: " << *literal; fprintf(stderr, "%s\n", literal->ToString().c_str()); if (literal->shape().element_type() == xla::F32) { - float min = - *std::min_element(literal->f32s().begin(), literal->f32s().end()); - float max = - *std::max_element(literal->f32s().begin(), literal->f32s().end()); + float min = *std::min_element(literal->data().begin(), + literal->data().end()); + float max = *std::max_element(literal->data().begin(), + literal->data().end()); fprintf(stderr, "min: %a=%f\n", min, min); fprintf(stderr, "max: %a=%f\n", max, max); } diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index e595df3052c3de64de503d7627eff72dcba177ee..fe5d29a6b655a89d559eb1214c2b8dd54d34094c 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -191,9 +191,9 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice p) { - for (int64 i = 0; i < p.size(); ++i) { - if (p[i] != i) { +bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation) { + for (int64 i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) { return false; } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index b722095d1f38bf8a984c3ce9092a65f8e0baa911..bb2db2010c5e0da6ed3fde628eb5928d555815b2 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -239,11 +239,14 @@ std::vector Permute(tensorflow::gtl::ArraySlice permutation, // Override of the above that works around compile failures with gcc 7.1.1. // For details see https://github.com/tensorflow/tensorflow/issues/10843 +// Hide this workaround from MSVC as it causes ambiguous error. +#ifndef _MSC_VER template std::vector Permute(tensorflow::gtl::ArraySlice permutation, const std::vector& input) { return Permute(permutation, input); } +#endif // Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i. std::vector InversePermutation( @@ -395,6 +398,31 @@ std::vector> CommonFactors( // Removes illegal characters from filenames. string SanitizeFileName(string file_name); +// Simple wrapper around std::all_of. +template +bool c_all_of(Container container, Predicate predicate) { + return std::all_of(std::begin(container), std::end(container), predicate); +} + +// Simple wrapper around std::transform. +template +OutputIterator c_transform(InputContainer input_container, + OutputIterator output_iterator, + UnaryOperation unary_op) { + return std::transform(std::begin(input_container), std::end(input_container), + output_iterator, unary_op); +} + +// Simple wrapper around std::copy_if. +template +OutputIterator c_copy_if(InputContainer input_container, + OutputIterator output_iterator, + UnaryPredicate predicate) { + return std::copy_if(std::begin(input_container), std::end(input_container), + output_iterator, predicate); +} + } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 2e0eba8de0100fb4e7e45348618febd778c88c9a..224eb2a20c8fc5ac4bfe2bb92a65a3bd178dbaf6 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -88,6 +88,11 @@ string ToString(const Window& window) { return StrCat(dim.window_dilation()); }); } + if (HasWindowReversal(window)) { + add_field(" rhs_reversal", [](const WindowDimension& dim) { + return StrCat(dim.window_reversal() ? 1 : 0); + }); + } return str; } @@ -141,10 +146,25 @@ bool HasWindowDilation(const Window& window) { return false; } +bool HasWindowReversal(const Window& window) { + for (const auto& dim : window.dimensions()) { + if (dim.window_reversal()) { + return true; + } + } + return false; +} + bool HasDilation(const Window& window) { return HasBaseDilation(window) || HasWindowDilation(window); } +bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) { + const WindowDimension& window_dim = window.dimensions(logical_dim); + return window_dim.size() == 1 && window_dim.stride() == 1 && + window_dim.padding_low() == 0 && window_dim.padding_high() == 0; +} + int64 DilatedBound(int64 bound, int64 dilation) { CHECK_GE(bound, 0); CHECK_GE(dilation, 1); diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 235cb2d59d451a25dc4f824ab488f8cef6b03bfb..17c388fc0b551ec227802434b7db435c4d25d985 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -39,6 +39,12 @@ bool HasBaseDilation(const Window& window); bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); +bool HasWindowReversal(const Window& window); + +// Returns true if the given logical dimension is inactive in the sense that it +// has window bound 1, no striding and no padding. +bool IsInactiveWindowDimension(const Window& window, int64 logical_dim); + // Returns the new bound after dilation. // // If a window with the given bound in some dimension is dilated with the given diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 127e5e81ac6d21945c7125ef913d236e8892758e..fda1a4c27b6dea1b7e4dee76de976f93ba61c007 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -175,6 +175,10 @@ message DebugOptions { // assignments, if available. bool xla_hlo_tfgraph_device_scopes = 93; + // If true, the GPU backend is free to use cudnn for HLO batch normalization + // ops. + bool xla_gpu_use_cudnn_batchnorm = 94; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 2ba1a2d904e45e582ee4e8a4ea889ee69d55e747..3aea0217539b89b5d60ecfaf2605eee4b69af728 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -114,6 +114,17 @@ message PaddingConfig { repeated PaddingConfigDimension dimensions = 1; } +// A format specifies the method used by a layout to store an array in memory. +enum Format { + INVALID_FORMAT = 0; + // The default layout, with exactly one storage location per element (ignoring + // padding). + DENSE = 1; + // A sparsely encoded layout, providing only the index/value pairs of non-zero + // elements. + SPARSE = 2; +} + // A layout describes how the array is placed in (1D) memory space. This // includes the minor-to-major ordering of dimensions within a shape, as well as // any padding present in those dimensions. @@ -124,21 +135,30 @@ message PaddingConfig { // // See the XLA documentation for more information on shapes and layouts. message Layout { + // The method used to store the data in memory. The format determines which of + // the other fields are used by the layout. + Format format = 4; + // Sequence of dimension numbers, from minor (fastest varying index) to major // (slowest varying index). This field is required. repeated int64 minor_to_major = 1; - // The width to which the layout of each dimension is padded up - // to. If present, the size of the padded_dimensions must equal the - // rank of the shape. The padding appears at the end of a dimension, - // not at the beginning. This kind of padding, unlike padding in - // e.g. convolution, is not part of the shape. + // The width to which the layout of each dimension is padded up to. If + // present, the size of the padded_dimensions must equal the rank of the + // shape. The padding appears at the end of a dimension, not at the + // beginning. This kind of padding, unlike padding in e.g. convolution, is not + // part of the shape. This field must be unset unless the format is DENSE. repeated int64 padded_dimensions = 2; - // Describes the values in the padding specified by - // padded_dimensions. + // Describes the values in the padding specified by padded_dimensions. This + // field must be unset unless the format is DENSE. PaddingValue padding_value = 3; + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be unset unless the format is SPARSE. + int64 max_sparse_elements = 5; + // Important: if any field is added, be sure to modify ShapeUtil::Equal() // appropriately to account for the new field. } @@ -321,7 +341,8 @@ message LiteralProto { // The F16s and BF16s are encoded in little endian byte order bytes f16s = 11; bytes bf16s = 13; - // Next = 14 + repeated int64 sparse_indices = 14; + // Next = 15 } message WindowDimension { @@ -498,6 +519,23 @@ message CustomCallRequest { Shape shape = 4; } +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. + repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. + repeated int64 rhs_batch_dimensions = 4; +}; + +message DotRequest { + ComputationDataHandle lhs = 2; + ComputationDataHandle rhs = 3; + DotDimensionNumbers dimension_numbers = 4; +} + message MapRequest { repeated ComputationDataHandle operands = 2; ComputationHandle to_apply = 3; @@ -651,6 +689,14 @@ message ConcatenateRequest { int64 dimension = 3; } +message ConditionalRequest { + ComputationDataHandle predicate = 2; + ComputationDataHandle true_operand = 3; + ComputationHandle true_computation = 4; + ComputationDataHandle false_operand = 5; + ComputationHandle false_computation = 6; +} + message WhileRequest { ComputationHandle condition = 2; ComputationHandle body = 3; @@ -732,9 +778,6 @@ enum BinaryOperation { BINOP_LT = 9; BINOP_NE = 10; - // Dot product, matrix multiply. - BINOP_DOT = 12; - // Element-wise maximum. BINOP_MAX = 14; @@ -780,9 +823,7 @@ enum RandomDistribution { // parameter[0] and standard deviation parameter[1]. RNG_NORMAL = 2; - // Creates a Bernoulli-distribution-generated random number with mean - // parameter[0]. - RNG_BERNOULLI = 3; + // Next: 4 } message RngRequest { @@ -885,6 +926,7 @@ message OpRequest { ConvolveRequest convolve_request = 8; CrossReplicaSumRequest cross_replica_sum_request = 9; CustomCallRequest custom_call_request = 10; + DotRequest dot_request = 43; DynamicSliceRequest dynamic_slice_request = 11; DynamicUpdateSliceRequest dynamic_update_slice_request = 12; GetTupleElementRequest get_tuple_element_request = 13; @@ -914,7 +956,8 @@ message OpRequest { BatchNormInferenceRequest batch_norm_inference_request = 38; FftRequest fft_request = 41; ConvertRequest bitcast_convert_request = 42; - // Next: 43 + ConditionalRequest conditional_request = 44; + // Next: 45 } } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 61f7821519bc1d053ee3b273a6b36b9dbd973245..8bed0fabd743c9cf9a51fe574401ae42730d15b4 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -6,10 +6,16 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") py_library( name = "contrib_py", - srcs = glob(["**/*.py"]), + srcs = glob( + ["**/*.py"], + exclude = [ + "**/*_test.py", + ], + ), srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ @@ -19,6 +25,7 @@ py_library( "//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/cloud:cloud_py", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", + "//tensorflow/contrib/coder:coder_ops_py", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", @@ -48,6 +55,7 @@ py_library( "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", "//tensorflow/contrib/legacy_seq2seq:seq2seq_py", + "//tensorflow/contrib/libsvm", "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", @@ -68,6 +76,7 @@ py_library( "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", + "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/resampler:resampler_py", @@ -95,7 +104,7 @@ py_library( "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]), + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]), ) cc_library( @@ -104,11 +113,11 @@ cc_library( deps = [ "//tensorflow/contrib/batching:batch_ops_kernels", "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", + "//tensorflow/contrib/coder:all_kernels", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", - "//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels", "//tensorflow/contrib/rnn:all_kernels", "//tensorflow/contrib/seq2seq:beam_search_ops_kernels", @@ -116,7 +125,9 @@ cc_library( "//tensorflow/contrib/tensor_forest:stats_ops_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels", "//tensorflow/contrib/text:all_kernels", - ], + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ + "//tensorflow/contrib/nccl:nccl_kernels", + ]), ) cc_library( @@ -125,6 +136,7 @@ cc_library( deps = [ "//tensorflow/contrib/batching:batch_ops_op_lib", "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", + "//tensorflow/contrib/coder:all_ops", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 08247c6b38a4df663ad28a6b4d3c41a1da41a020..f600a8a99816586d6bd7d7ab51354888c435e739 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.contrib import bayesflow from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver +from tensorflow.contrib import coder from tensorflow.contrib import compiler from tensorflow.contrib import copy_graph from tensorflow.contrib import crf @@ -82,13 +83,14 @@ from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager from tensorflow.contrib.lite.python import lite from tensorflow.contrib.ndlstm import python as ndlstm +from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs from tensorflow.contrib.summary import summary from tensorflow.python.util.lazy_loader import LazyLoader -ffmpeg = LazyLoader("ffmpeg", - globals(), "tensorflow.contrib.ffmpeg") +ffmpeg = LazyLoader("ffmpeg", globals(), + "tensorflow.contrib.ffmpeg") del LazyLoader del absolute_import diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index a5057da9fd43a88575813613d6ac9d17fd2b2e28..28f60b34996945d573facc665c01d0bc10cf5cd1 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -744,13 +744,13 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f): level_2_output = upper_level_f(up_values) # Third stage: propagate within each worker using NCCL Broadcast for w in range(0, num_workers): - dst_devices = per_worker_devices[w][1:] - send_op, dst_tensors = nccl.broadcast(level_2_output[w], dst_devices) - # NOTE: need control dependency to ensure send_op executes - with ops.control_dependencies([send_op]): - with ops.device(per_worker_devices[w][0]): - dst_tensors.insert(0, array_ops.identity(level_2_output[w])) - down_values[w] = dst_tensors + dst_tensors = [] + with ops.device(per_worker_devices[w][0]): + broadcast_src = nccl.broadcast(array_ops.identity(level_2_output[w])) + for d in per_worker_devices[w]: + with ops.device(d): + dst_tensors.append(array_ops.identity(broadcast_src)) + down_values[w] = dst_tensors output_tensors = [v for sublist in down_values for v in sublist] if len(shape) > 1: output_tensors = _reshape_tensors(output_tensors, shape) diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md index f49e5857fe5255c2459793cb1389052a2ff5f88f..b8d73bf24ce60e0b3850d4f39ac9e6d6c2194a02 100644 --- a/tensorflow/contrib/android/README.md +++ b/tensorflow/contrib/android/README.md @@ -15,9 +15,9 @@ For prebuilt libraries, see the page for a recent build. The TensorFlow Inference Interface is also available as a -[JCenter package](https://bintray.com/google/tensorflow/tensorflow-android) and -can be included quite simply in your android project with a couple of lines in -the project's `build.gradle` file: +[JCenter package](https://bintray.com/google/tensorflow/tensorflow) +(see the tensorflow-android directory) and can be included quite simply in your +android project with a couple of lines in the project's `build.gradle` file: ``` allprojects { @@ -32,9 +32,9 @@ dependencies { ``` This will tell Gradle to use the -[latest version](https://bintray.com/google/tensorflow/tensorflow-android/_latestVersion) +[latest version](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) of the TensorFlow AAR that has been released to -[https://bintray.com/google/tensorflow/tensorflow-android](https://bintray.com/google/tensorflow/tensorflow-android). +[JCenter](https://jcenter.bintray.com/org/tensorflow/tensorflow-android/). You may replace the `+` with an explicit version label if you wish to use a specific release of TensorFlow in your app. diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index aba356d6167658f125001cbed6e3190c716ee7d6..a115d1610e2334a6626f29674f3dd195e3a3c648 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -34,6 +34,8 @@ add_library(lib_tf STATIC IMPORTED ) set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION ${PREBUILT_DIR}/lib/libtensorflow-core.a) # Change to compile flags should be replicated into bazel build file +# TODO: Consider options other than -O2 for binary size. +# e.g. -Os for gcc, and -Oz for clang. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \ -std=c++11 -fno-rtti -fno-exceptions \ -O2 -Wno-narrowing -fomit-frame-pointer \ diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index a111cfecb366fe245150cc71d2c43662d0d69090..cd98f0e70335db715b8cb6c76a9d7df3e2280552 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -12,7 +12,7 @@ cc_library( name = "batch_scheduler_hdrs", hdrs = ["batch_scheduler.h"], deps = [ - "//tensorflow/core:framework_headers_lib", + "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs", ], ) @@ -20,18 +20,7 @@ cc_library( name = "batch_scheduler", hdrs = ["batch_scheduler.h"], deps = [ - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "batch_scheduler_test", - srcs = ["batch_scheduler_test.cc"], - deps = [ - ":batch_scheduler", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/core/kernels/batching_util:batch_scheduler", ], ) @@ -39,9 +28,7 @@ cc_library( name = "shared_batch_scheduler_hdrs", hdrs = ["shared_batch_scheduler.h"], deps = [ - ":batch_scheduler_hdrs", - "//tensorflow/contrib/batching/util:periodic_function_dynamic", - "//tensorflow/core:framework_headers_lib", + "//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs", ], ) @@ -49,46 +36,16 @@ cc_library( name = "shared_batch_scheduler", hdrs = ["shared_batch_scheduler.h"], deps = [ - ":batch_scheduler", - "//tensorflow/contrib/batching/util:periodic_function_dynamic", - "//tensorflow/core:lib", + "//tensorflow/core/kernels/batching_util:shared_batch_scheduler", ], alwayslink = 1, ) -tf_cc_test( - name = "shared_batch_scheduler_test", - srcs = ["shared_batch_scheduler_test.cc"], - deps = [ - ":shared_batch_scheduler", - "//tensorflow/contrib/batching/test_util:fake_clock_env", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "adaptive_shared_batch_scheduler", hdrs = ["adaptive_shared_batch_scheduler.h"], deps = [ - ":batch_scheduler", - "//tensorflow/contrib/batching/util:periodic_function_dynamic", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "adaptive_shared_batch_scheduler_test", - srcs = ["adaptive_shared_batch_scheduler_test.cc"], - tags = ["manual"], # b/69013768 - deps = [ - ":adaptive_shared_batch_scheduler", - "//tensorflow/contrib/batching/test_util:fake_clock_env", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler", ], ) @@ -96,34 +53,7 @@ cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], deps = [ - ":shared_batch_scheduler", - ], -) - -tf_cc_test( - name = "basic_batch_scheduler_test", - srcs = ["basic_batch_scheduler_test.cc"], - deps = [ - ":basic_batch_scheduler", - ":batch_scheduler", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -tf_cc_test( - name = "basic_batch_scheduler_benchmark", - srcs = ["basic_batch_scheduler_benchmark.cc"], - tags = [ - "local", - "manual", - ], - deps = [ - ":basic_batch_scheduler", - "//tensorflow/core:lib", - "//tensorflow/core:tensorflow", - "//tensorflow/core:test", + "//tensorflow/core/kernels/batching_util:basic_batch_scheduler", ], ) diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index 6ed177e001758ad8c566c7965e1ec10ae5235fc8..60861f83f450d3f67f21a46bdfa3fda223b9d2b4 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -16,447 +16,6 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/batching/batch_scheduler.h" -#include "tensorflow/contrib/batching/util/periodic_function.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { -namespace internal { -template -class ASBSBatch; - -template -class ASBSQueue; -} // namespace internal - -// Shared batch scheduler designed to minimize latency. The scheduler keeps -// track of a number of queues (one per model or model version) which are -// continuously enqueuing requests. The scheduler groups the requests into -// batches which it periodically sends off for processing (see -// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler -// prioritizes batches by age (i.e. the batch's oldest request) irrespective of -// queue. The scheduler will process the oldest batch at an adjustable rate, -// regardless of batch size. The user can provide feedback to help set this rate -// to achieve some goal (i.e. minimize overall latency, limit cpu usage, etc). -// -// The rate (or rather, the corresponding period) is adjusted each time a batch -// is processed, using an exponentially weighted moving average to smooth -// potentially noisy feedback: -// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N -// period *= (1 + K * emwa_feedback) -// -// Some potential use cases: -// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing -// involves serial processing by a device, from a latency perspective it is -// desirable to keep the device evenly loaded, avoiding the need to wait for -// the device to process prior batches. -// feedback = num_pending_on_device() - desired_pending. -// CPU utilization - If the batch processing is cpu dominated, you can reap -// latency gains when underutilized by increasing the processing rate, but -// back the rate off when the load increases to avoid overload. -// feedback = cpu_rate() - desired_cpu_rate. - -template -class AdaptiveSharedBatchScheduler - : public std::enable_shared_from_this< - AdaptiveSharedBatchScheduler> { - public: - struct Options { - // The name to use for the pool of batch threads. - string thread_pool_name = {"batch_threads"}; - // Number of batch processing threads; equivalently the maximum number of - // concurrently running batches. - int64 num_batch_threads = port::NumSchedulableCPUs(); - // The environment to use (typically only overridden by test code). - Env* env = Env::Default(); - // Initial batch scheduling period in microseconds. Will be altered for - // non-zero rate_feedback. - double initial_scheduling_period_micros = 500; - // Minimum batch scheduling period in microseconds. Recommend setting this - // value greater than 0, otherwise it may take a while to recover from a - // sustained time of negative scheduling_period_feedback (which may occur - // under low load). - double min_scheduling_period_micros = 100; - // Maximum batch scheduling period in microseconds. - double max_scheduling_period_micros = 10000; - // Feedback function used to modify the scheduling period each time a batch - // is scheduled. Should return values roughly O(1), with positive values - // resulting in an increased period. - std::function scheduling_period_feedback{[] { return 0.; }}; - // To handle potentially noisy scheduling_period_feedback, the period is - // adjusted using an exponentially weighted moving average over the previous - // feedback_smoothing_batches batches. Must be greater than 0. - int64 feedback_smoothing_batches = 10; - }; - - // Ownership is shared between the caller of Create() and any queues created - // via AddQueue(). - static Status Create( - const Options& options, - std::shared_ptr>* scheduler); - - struct QueueOptions { - // Maximum size of each batch. - int max_batch_size = 1000; - // Maximum number of enqueued (i.e. non-scheduled) batches. - int max_enqueued_batches = 10; - }; - - using BatchProcessor = std::function>)>; - - // Adds queue (and its callback) to be managed by this scheduler. - Status AddQueue(const QueueOptions& options, - BatchProcessor process_batch_callback, - std::unique_ptr>* queue); - - private: - // access to AddBatch, RemoveQueue, GetEnv. - friend class internal::ASBSQueue; - - explicit AdaptiveSharedBatchScheduler(const Options& options); - - // Batch scheduling function which runs every scheduling_period_ microseconds. - void ProcessOneBatch(); - - // Notifies scheduler of non-empty batch which is eligible for processing. - void AddBatch(internal::ASBSBatch*); - - // Removes queue from scheduler. - void RemoveQueue(const internal::ASBSQueue* queue); - - Env* GetEnv() const { return options_.env; } - - const Options options_; - - struct BatchCompare { - bool operator()(const internal::ASBSBatch* a, - const internal::ASBSBatch* b); - }; - - // Collection of batches added by AddBatch, ordered by age. Owned by scheduler - // until they are released for processing. - std::priority_queue*, - std::vector*>, BatchCompare> - batches_ GUARDED_BY(mu_); - - // Unowned queues and callbacks added by AddQueue. - std::unordered_map*, BatchProcessor> - queues_and_callbacks_ GUARDED_BY(mu_); - - mutex mu_; - - // Responsible for running ProcessOneBatch. PeriodicFunction was used in order - // to check for deletion so that the thread can be shut down. - std::unique_ptr scheduling_thread_; - - // Responsible for running the batch processing callbacks. - std::unique_ptr batch_thread_pool_; - - // Time interval in microseconds between successive ProcessOneBatch calls. - double scheduling_period_; - - // Exponentially weighted moving average of - // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch - // call. - double ewma_feedback_ = 0; - - TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler); -}; - -////////////////////////////////////////////////////////// -// Implementation details follow. API users need not read. - -namespace internal { -// Consolidates tasks into batches, passing them off to the -// AdaptiveSharedBatchScheduler for processing. -template -class ASBSQueue : public BatchScheduler { - public: - using QueueOptions = - typename AdaptiveSharedBatchScheduler::QueueOptions; - - ASBSQueue(std::shared_ptr> scheduler, - const QueueOptions& options); - - ~ASBSQueue() override; - - // Adds task to current batch. Fails if the task size is larger than the batch - // size or if the current batch is full and this queue's number of outstanding - // batches is at its maximum. - Status Schedule(std::unique_ptr* task) override; - - // Number of tasks waiting to be scheduled. - size_t NumEnqueuedTasks() const override; - - // Number of size 1 tasks which could currently be scheduled without failing. - size_t SchedulingCapacity() const override; - - // Notifies queue that a batch is about to be scheduled; the queue should not - // place any more tasks in this batch. - void ReleaseBatch(const ASBSBatch* batch); - - private: - std::shared_ptr> scheduler_; - const QueueOptions options_; - // Owned by scheduler_. - ASBSBatch* current_batch_ GUARDED_BY(mu_) = nullptr; - int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0; - int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0; - mutable mutex mu_; - TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue); -}; - -// Batch which remembers when and by whom it was created. -template -class ASBSBatch : public Batch { - public: - ASBSBatch(ASBSQueue* queue, int64 creation_time_micros) - : queue_(queue), creation_time_micros_(creation_time_micros) {} - - ~ASBSBatch() override {} - - ASBSQueue* queue() const { return queue_; } - - int64 creation_time_micros() const { return creation_time_micros_; } - - private: - ASBSQueue* queue_; - const int64 creation_time_micros_; - TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch); -}; -} // namespace internal - -// ---------------- AdaptiveSharedBatchScheduler ---------------- - -template -Status AdaptiveSharedBatchScheduler::Create( - const Options& options, - std::shared_ptr>* scheduler) { - if (options.num_batch_threads < 1) { - return errors::InvalidArgument("num_batch_threads must be positive; was ", - options.num_batch_threads); - } - if (options.min_scheduling_period_micros < 0) { - return errors::InvalidArgument( - "min_scheduling_period_micros must be >= 0; was ", - options.min_scheduling_period_micros); - } - if (options.min_scheduling_period_micros > - options.initial_scheduling_period_micros) { - return errors::InvalidArgument( - "initial_scheduling_period_micros (", - options.initial_scheduling_period_micros, - ") must be >= min_scheduling_period_micros (", - options.min_scheduling_period_micros, ")"); - } - if (options.initial_scheduling_period_micros > - options.max_scheduling_period_micros) { - return errors::InvalidArgument( - "initial_scheduling_period_micros (", - options.initial_scheduling_period_micros, - ") must be <= max_scheduling_period_micros (", - options.max_scheduling_period_micros, ")"); - } - if (options.feedback_smoothing_batches < 1) { - return errors::InvalidArgument( - "feedback_smoothing_batches must be positive; was ", - options.feedback_smoothing_batches); - } - scheduler->reset(new AdaptiveSharedBatchScheduler(options)); - return Status::OK(); -} - -template -AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( - const Options& options) - : options_(options), - scheduling_period_(options.initial_scheduling_period_micros) { - PeriodicFunction::Options opts; - opts.thread_name_prefix = "scheduling_thread"; - opts.env = GetEnv(); - scheduling_thread_.reset( - new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); - batch_thread_pool_.reset(new thread::ThreadPool( - GetEnv(), options.thread_pool_name, options.num_batch_threads)); -} - -template -Status AdaptiveSharedBatchScheduler::AddQueue( - const QueueOptions& options, BatchProcessor process_batch_callback, - std::unique_ptr>* queue) { - if (options.max_batch_size <= 0) { - return errors::InvalidArgument("max_batch_size must be positive; was ", - options.max_batch_size); - } - if (options.max_enqueued_batches <= 0) { - return errors::InvalidArgument( - "max_enqueued_batches must be positive; was ", - options.max_enqueued_batches); - } - internal::ASBSQueue* asbs_queue_raw; - queue->reset(asbs_queue_raw = new internal::ASBSQueue( - this->shared_from_this(), options)); - mutex_lock l(mu_); - queues_and_callbacks_[asbs_queue_raw] = process_batch_callback; - return Status::OK(); -} - -template -void AdaptiveSharedBatchScheduler::AddBatch( - internal::ASBSBatch* batch) { - mutex_lock l(mu_); - batches_.push(batch); -} - -template -void AdaptiveSharedBatchScheduler::RemoveQueue( - const internal::ASBSQueue* queue) { - mutex_lock l(mu_); - queues_and_callbacks_.erase(queue); -} - -template -void AdaptiveSharedBatchScheduler::ProcessOneBatch() { - static const double kFeedbackMultiplier = .001; - internal::ASBSBatch* batch = nullptr; - BatchProcessor callback; - const int64 start_time_micros = GetEnv()->NowMicros(); - { - mutex_lock l(mu_); - if (!batches_.empty()) { - batch = batches_.top(); - batches_.pop(); - callback = queues_and_callbacks_[batch->queue()]; - } - } - if (batch != nullptr) { - double feedback = options_.scheduling_period_feedback(); - const int64 N = options_.feedback_smoothing_batches; - ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N; - scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_); - if (scheduling_period_ < options_.min_scheduling_period_micros) { - scheduling_period_ = options_.min_scheduling_period_micros; - } else if (scheduling_period_ > options_.max_scheduling_period_micros) { - scheduling_period_ = options_.max_scheduling_period_micros; - } - // Queue may destroy itself after ReleaseBatch is called. - batch->queue()->ReleaseBatch(batch); - batch_thread_pool_->Schedule([callback, batch] { - callback(std::unique_ptr>(batch)); - }); - } - const int64 sleep_time = - scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros); - if (sleep_time > 0) { - GetEnv()->SleepForMicroseconds(sleep_time); - } -} - -template -bool AdaptiveSharedBatchScheduler::BatchCompare::operator()( - const internal::ASBSBatch* a, - const internal::ASBSBatch* b) { - return a->creation_time_micros() > b->creation_time_micros(); -} - -// ---------------- ASBSQueue ---------------- - -namespace internal { -template -ASBSQueue::ASBSQueue( - std::shared_ptr> scheduler, - const QueueOptions& options) - : scheduler_(scheduler), options_(options) {} - -template -ASBSQueue::~ASBSQueue() { - // Wait until last batch has been scheduled. - const int kSleepMicros = 1000; - for (;;) { - { - mutex_lock l(mu_); - if (num_enqueued_batches_ == 0) { - break; - } - } - scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros); - } - scheduler_->RemoveQueue(this); -} - -template -Status ASBSQueue::Schedule(std::unique_ptr* task) { - ASBSBatch* new_batch = nullptr; - size_t size = (*task)->size(); - if (size > options_.max_batch_size) { - return errors::InvalidArgument("Task size ", size, - " is larger than maximum batch size ", - options_.max_batch_size); - } - { - mutex_lock l(mu_); - // Current batch is full, create another if allowed. - if (current_batch_ && - current_batch_->size() + size > options_.max_batch_size) { - if (num_enqueued_batches_ >= options_.max_enqueued_batches) { - return errors::Unavailable("The batch scheduling queue is full"); - } - current_batch_->Close(); - current_batch_ = nullptr; - } - if (!current_batch_) { - num_enqueued_batches_++; - current_batch_ = new_batch = - new ASBSBatch(this, scheduler_->GetEnv()->NowMicros()); - } - current_batch_->AddTask(std::move(*task)); - num_enqueued_tasks_++; - } - if (new_batch != nullptr) scheduler_->AddBatch(new_batch); - return Status::OK(); -} - -template -void ASBSQueue::ReleaseBatch(const ASBSBatch* batch) { - mutex_lock l(mu_); - num_enqueued_batches_--; - num_enqueued_tasks_ -= batch->num_tasks(); - if (batch == current_batch_) { - current_batch_->Close(); - current_batch_ = nullptr; - } -} - -template -size_t ASBSQueue::NumEnqueuedTasks() const { - mutex_lock l(mu_); - return num_enqueued_tasks_; -} - -template -size_t ASBSQueue::SchedulingCapacity() const { - mutex_lock l(mu_); - const int current_batch_capacity = - current_batch_ ? options_.max_batch_size - current_batch_->size() : 0; - const int spare_batches = - options_.max_enqueued_batches - num_enqueued_batches_; - return spare_batches * options_.max_batch_size + current_batch_capacity; -} -} // namespace internal -} // namespace serving -} // namespace tensorflow +#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h index 9d3805fbaf39978159dd2f4a754e6d41a07acf6a..63ba8fcf45d8e6caad14c267bb19c0bc4eea20bf 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler.h +++ b/tensorflow/contrib/batching/basic_batch_scheduler.h @@ -16,249 +16,6 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/batching/shared_batch_scheduler.h" - -namespace tensorflow { -namespace serving { - -// A BatchScheduler implementation geared toward handling a single request type -// running on a specific set of hardware resources. A typical scenario is one in -// which all requests invoke the same machine-learned model on one GPU. -// -// If there are, say, two GPUs and two models each bound to one of the GPUs, one -// could use two BasicBatchScheduler instances to schedule the two model/GPU -// combinations independently. If multiple models must share a given GPU or -// other hardware resource, consider using SharedBatchScheduler instead. -// -// -// PARAMETERS AND BEHAVIOR: -// -// BasicBatchScheduler runs a fixed pool of threads, which it uses to process -// batches of tasks. It enforces a maximum batch size, and enqueues a bounded -// number of tasks. If the queue is nearly empty, such that a full batch cannot -// be formed, when a thread becomes free, it anyway schedules a batch -// immediately if a task has been in the queue for longer than a given timeout -// parameter. If the timeout parameter is set to 0, then the batch threads will -// always be kept busy (unless there are zero tasks waiting to be processed). -// -// For online serving, it is recommended to set the maximum number of enqueued -// batches worth of tasks equal to the number of batch threads, which allows -// enqueuing of enough tasks s.t. if every thread becomes available it can be -// kept busy, but no more. For bulk processing jobs and throughput-oriented -// benchmarks, you may want to set it much higher. -// -// When Schedule() is called, if the queue is full the call will fail with an -// UNAVAILABLE error (after which the client may retry again later). If the call -// succeeds, the maximum time the task will spend in the queue before being -// placed in a batch and assigned to a thread for processing, is the greater of: -// - the maximum time to process ceil(max_enqueued_batches/num_batch_threads) -// (1 in the recommended configuration) batches of previously-submitted tasks -// - the configured timeout parameter (which can be 0, as mentioned above) -// -// Unlike StreamingBatchScheduler, when BasicBatchScheduler assigns a batch to a -// thread, it closes the batch. The process-batch callback may assume that every -// batch it receives is closed at the outset. -// -// -// RECOMMENDED USE-CASES: -// -// BasicBatchScheduler is suitable for use-cases that feature a single kind of -// request (e.g. a server performing inference with a single machine-learned -// model, possibly evolving over time), with loose versioning semantics. -// Concretely, the following conditions should hold: -// -// A. All requests batched onto a given resource (e.g. a hardware accelerator, -// or a pool accelerators) are of the same type. For example, they all -// invoke the same machine-learned model. -// -// These variations are permitted: -// - The model may reside in a single servable, or it may be spread across -// multiple servables that are used in unison (e.g. a vocabulary lookup -// table servable and a tensorflow session servable). -// - The model's servable(s) may be static, or they may evolve over time -// (successive servable versions). -// - Zero or more of the servables are used in the request thread; the rest -// are used in the batch thread. In our running example, the vocabulary -// lookups and tensorflow runs may both be performed in the batch thread, -// or alternatively the vocabulary lookup may occur in the request thread -// with only the tensorflow run performed in the batch thread. -// -// In contrast, BasicBatchScheduler is not a good fit if the server -// hosts multiple distinct models running on a pool accelerators, with each -// request specifying which model it wants to use. BasicBatchScheduler -// has no facility to time-multiplex the batch threads across multiple -// models in a principled way. More basically, it cannot ensure that a given -// batch doesn't contain a mixture of requests for different models. -// -// B. Requests do not specify a particular version of the servable(s) that must -// be used. Instead, each request is content to use the "latest" version. -// -// BasicBatchScheduler does not constrain which requests get grouped -// together into a batch, so using this scheduler there is no way to achieve -// cohesion of versioned requests to version-specific batches. -// -// C. No servable version coordination needs to be performed between the -// request threads and the batch threads. Often, servables are only used in -// the batch threads, in which case this condition trivially holds. If -// servables are used in both threads, then the use-case must tolerate -// version skew across the servables used in the two kinds of threads. -// -// -// EXAMPLE USE-CASE FLOW: -// -// For such use-cases, request processing via BasicBatchScheduler generally -// follows this flow (given for illustration; variations are possible): -// 1. Optionally perform some pre-processing on each request in the request -// threads. -// 2. Route the requests to the batch scheduler, as batching::Task objects. -// (Since all requests are of the same type and are not versioned, the -// scheduler is free to group them into batches arbitrarily.) -// 3. Merge the requests into a single batched representation B. -// 4. Obtain handles to the servable(s) needed to process B. The simplest -// approach is to obtain the latest version of each servable. Alternatively, -// if cross-servable consistency is required (e.g. the vocabulary lookup -// table's version number must match that of the tensorflow session), -// identify an appropriate version number and obtain the servable handles -// accordingly. -// 5. Process B using the obtained servable handles, and split the result into -// individual per-request units. -// 6. Perform any post-processing in the batch thread and/or request thread. -// -// -// PERFORMANCE TUNING: See README.md. -// -template -class BasicBatchScheduler : public BatchScheduler { - public: - // TODO(b/25089730): Tune defaults based on best practices as they develop. - // (Keep them mirrored to the ones in SharedBatchScheduler::QueueOptions and - // SharedBatchScheduler::Options.) - struct Options { - // The maximum size of each batch. - // - // The scheduler may form batches of any size between 1 and this number - // (inclusive). If there is a need to quantize the batch sizes, i.e. only - // submit batches whose size is in a small set of allowed sizes, that can be - // done by adding padding in the process-batch callback. - int max_batch_size = 1000; - - // If a task has been enqueued for this amount of time (in microseconds), - // and a thread is available, the scheduler will immediately form a batch - // from enqueued tasks and assign the batch to the thread for processing, - // even if the batch's size is below 'max_batch_size'. - // - // This parameter offers a way to bound queue latency, so that a task isn't - // stuck in the queue indefinitely waiting for enough tasks to arrive to - // make a full batch. (The latency bound is given in the class documentation - // above.) - // - // The goal is to smooth out batch sizes under low request rates, and thus - // avoid latency spikes. - int64 batch_timeout_micros = 0; - - // The name to use for the pool of batch threads. - string thread_pool_name = {"batch_threads"}; - - // The number of threads to use to process batches. - // Must be >= 1, and should be tuned carefully. - int num_batch_threads = port::NumSchedulableCPUs(); - - // The maximum allowable number of enqueued (accepted by Schedule() but - // not yet being processed on a batch thread) tasks in terms of batches. - // If this limit is reached, Schedule() will return an UNAVAILABLE error. - // See the class documentation above for guidelines on how to tune this - // parameter. - int max_enqueued_batches = 10; - - // The following options are typically only overridden by test code. - - // The environment to use. - Env* env = Env::Default(); - }; - static Status Create(const Options& options, - std::function>)> - process_batch_callback, - std::unique_ptr* scheduler); - - ~BasicBatchScheduler() override = default; - - Status Schedule(std::unique_ptr* task) override; - size_t NumEnqueuedTasks() const override; - size_t SchedulingCapacity() const override; - - private: - explicit BasicBatchScheduler( - std::unique_ptr> shared_scheduler_queue); - - // This class is merely a thin wrapper around a SharedBatchScheduler with a - // single queue. - std::unique_ptr> shared_scheduler_queue_; - - TF_DISALLOW_COPY_AND_ASSIGN(BasicBatchScheduler); -}; - -////////// -// Implementation details follow. API users need not read. - -template -Status BasicBatchScheduler::Create( - const Options& options, - std::function>)> - process_batch_callback, - std::unique_ptr* scheduler) { - typename SharedBatchScheduler::Options shared_scheduler_options; - shared_scheduler_options.thread_pool_name = options.thread_pool_name; - shared_scheduler_options.num_batch_threads = options.num_batch_threads; - shared_scheduler_options.env = options.env; - std::shared_ptr> shared_scheduler; - TF_RETURN_IF_ERROR(SharedBatchScheduler::Create( - shared_scheduler_options, &shared_scheduler)); - - typename SharedBatchScheduler::QueueOptions - shared_scheduler_queue_options; - shared_scheduler_queue_options.max_batch_size = options.max_batch_size; - shared_scheduler_queue_options.batch_timeout_micros = - options.batch_timeout_micros; - shared_scheduler_queue_options.max_enqueued_batches = - options.max_enqueued_batches; - std::unique_ptr> shared_scheduler_queue; - TF_RETURN_IF_ERROR(shared_scheduler->AddQueue(shared_scheduler_queue_options, - process_batch_callback, - &shared_scheduler_queue)); - - scheduler->reset( - new BasicBatchScheduler(std::move(shared_scheduler_queue))); - return Status::OK(); -} - -template -Status BasicBatchScheduler::Schedule( - std::unique_ptr* task) { - return shared_scheduler_queue_->Schedule(task); -} - -template -size_t BasicBatchScheduler::NumEnqueuedTasks() const { - return shared_scheduler_queue_->NumEnqueuedTasks(); -} - -template -size_t BasicBatchScheduler::SchedulingCapacity() const { - return shared_scheduler_queue_->SchedulingCapacity(); -} - -template -BasicBatchScheduler::BasicBatchScheduler( - std::unique_ptr> shared_scheduler_queue) - : shared_scheduler_queue_(std::move(shared_scheduler_queue)) {} - -} // namespace serving -} // namespace tensorflow +#include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h" #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index a5072f439abad3c5db79a514a7f2baff0b021b39..3afce2761f748136f4d556017823db8dbd4af50e 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -13,264 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Abstractions for processing small tasks in a batched fashion, to reduce -// processing times and costs that can be amortized across multiple tasks. -// -// The core class is BatchScheduler, which groups tasks into batches. -// -// BatchScheduler encapsulates logic for aggregating multiple tasks into a -// batch, and kicking off processing of a batch on a thread pool it manages. -// -// This file defines an abstract BatchScheduler class. - #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ -#include -#include -#include -#include -#include -#include - -#include "tensorflow/core/lib/core/notification.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { - -// The abstract superclass for a unit of work to be done as part of a batch. -// -// An implementing subclass typically contains (or points to): -// (a) input data; -// (b) a thread-safe completion signal (e.g. a Notification); -// (c) a place to store the outcome (success, or some error), upon completion; -// (d) a place to store the output data, upon success. -// -// Items (b), (c) and (d) are typically non-owned pointers to data homed -// elsewhere, because a task's ownership gets transferred to a BatchScheduler -// (see below) and it may be deleted as soon as it is done executing. -class BatchTask { - public: - virtual ~BatchTask() = default; - - // Returns the size of the task, in terms of how much it contributes to the - // size of a batch. (A batch's size is the sum of its task sizes.) - virtual size_t size() const = 0; -}; - -// A thread-safe collection of BatchTasks, to be executed together in some -// fashion. -// -// At a given time, a batch is either "open" or "closed": an open batch can -// accept new tasks; a closed one cannot. A batch is monotonic: initially it is -// open and tasks can be added to it; then it is closed and its set of tasks -// remains fixed for the remainder of its life. A closed batch cannot be re- -// opened. Tasks can never be removed from a batch. -// -// Type parameter TaskType must be a subclass of BatchTask. -template -class Batch { - public: - Batch() = default; - virtual ~Batch(); // Blocks until the batch is closed. - - // Appends 'task' to the batch. After calling AddTask(), the newly-added task - // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1). - // Dies if the batch is closed. - void AddTask(std::unique_ptr task); - - // Removes the most recently added task. Returns nullptr if the batch is - // empty. - std::unique_ptr RemoveTask(); - - // Returns the number of tasks in the batch. - int num_tasks() const; - - // Returns true iff the batch contains 0 tasks. - bool empty() const; - - // Returns a reference to the ith task (in terms of insertion order). - const TaskType& task(int i) const; - - // Returns a pointer to the ith task (in terms of insertion order). - TaskType* mutable_task(int i); - - // Returns the sum of the task sizes. - size_t size() const; - - // Returns true iff the batch is currently closed. - bool IsClosed() const; - - // Blocks until the batch is closed. - void WaitUntilClosed() const; - - // Marks the batch as closed. Dies if called more than once. - void Close(); - - private: - mutable mutex mu_; - - // The tasks in the batch. - std::vector> tasks_ GUARDED_BY(mu_); - - // The sum of the sizes of the tasks in 'tasks_'. - size_t size_ GUARDED_BY(mu_) = 0; - - // Whether the batch has been closed. - Notification closed_; - - TF_DISALLOW_COPY_AND_ASSIGN(Batch); -}; - -// An abstract batch scheduler class. Collects individual tasks into batches, -// and processes each batch on a pool of "batch threads" that it manages. The -// actual logic for processing a batch is accomplished via a callback. -// -// Type parameter TaskType must be a subclass of BatchTask. -template -class BatchScheduler { - public: - virtual ~BatchScheduler() = default; - - // Submits a task to be processed as part of a batch. - // - // Ownership of '*task' is transferred to the callee iff the method returns - // Status::OK. In that case, '*task' is left as nullptr. Otherwise, '*task' is - // left as-is. - // - // If no batch processing capacity is available to process this task at the - // present time, and any task queue maintained by the implementing subclass is - // full, this method returns an UNAVAILABLE error code. The client may retry - // later. - // - // Other problems, such as the task size being larger than the maximum batch - // size, yield other, permanent error types. - // - // In all cases, this method returns "quickly" without blocking for any - // substantial amount of time. If the method returns Status::OK, the task is - // processed asynchronously, and any errors that occur during the processing - // of the batch that includes the task can be reported to 'task'. - virtual Status Schedule(std::unique_ptr* task) = 0; - - // Returns the number of tasks that have been scheduled (i.e. accepted by - // Schedule()), but have yet to be handed to a thread for execution as part of - // a batch. Note that this returns the number of tasks, not the aggregate task - // size (so if there is one task of size 3 and one task of size 5, this method - // returns 2 rather than 8). - virtual size_t NumEnqueuedTasks() const = 0; - - // Returns a guaranteed number of size 1 tasks that can be Schedule()d without - // getting an UNAVAILABLE error. In a typical implementation, returns the - // available space on a queue. - // - // There are two important caveats: - // 1. The guarantee does not extend to varying-size tasks due to possible - // internal fragmentation of batches. - // 2. The guarantee only holds in a single-thread environment or critical - // section, i.e. if an intervening thread cannot call Schedule(). - // - // This method is useful for monitoring, or for guaranteeing a future slot in - // the schedule (but being mindful about the caveats listed above). - virtual size_t SchedulingCapacity() const = 0; -}; - -////////// -// Implementation details follow. API users need not read. - -template -Batch::~Batch() { - WaitUntilClosed(); -} - -template -void Batch::AddTask(std::unique_ptr task) { - DCHECK(!IsClosed()); - { - mutex_lock l(mu_); - size_ += task->size(); - tasks_.push_back(std::move(task)); - } -} - -template -std::unique_ptr Batch::RemoveTask() { - { - mutex_lock l(mu_); - if (tasks_.empty()) { - return nullptr; - } - std::unique_ptr task = std::move(tasks_.back()); - tasks_.pop_back(); - return task; - } -} - -template -int Batch::num_tasks() const { - { - mutex_lock l(mu_); - return tasks_.size(); - } -} - -template -bool Batch::empty() const { - { - mutex_lock l(mu_); - return tasks_.empty(); - } -} - -template -const TaskType& Batch::task(int i) const { - DCHECK_GE(i, 0); - { - mutex_lock l(mu_); - DCHECK_LT(i, tasks_.size()); - return *tasks_[i].get(); - } -} - -template -TaskType* Batch::mutable_task(int i) { - DCHECK_GE(i, 0); - { - mutex_lock l(mu_); - DCHECK_LT(i, tasks_.size()); - return tasks_[i].get(); - } -} - -template -size_t Batch::size() const { - { - mutex_lock l(mu_); - return size_; - } -} - -template -bool Batch::IsClosed() const { - return const_cast(&closed_)->HasBeenNotified(); -} - -template -void Batch::WaitUntilClosed() const { - const_cast(&closed_)->WaitForNotification(); -} - -template -void Batch::Close() { - closed_.Notify(); -} - -} // namespace serving -} // namespace tensorflow +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h index 41a3f99137ade2552432fee62ddce17d064148a4..7eb1e20c42283a38564f7686db0015f153f469ed 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/shared_batch_scheduler.h @@ -16,685 +16,6 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/batching/batch_scheduler.h" -#include "tensorflow/contrib/batching/util/periodic_function.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { -namespace internal { -template -class Queue; -} // namespace internal -} // namespace serving -} // namespace tensorflow - -namespace tensorflow { -namespace serving { - -// A batch scheduler for server instances that service multiple request types -// (e.g. multiple machine-learned models, or multiple versions of a model served -// concurrently), or even multiple distinct tasks for a given request. The -// scheduler multiplexes batches of different kinds of tasks onto a fixed-size -// thread pool (each batch contains tasks of a single type), in a carefully -// controlled manner. A common configuration is to set the number of threads -// equal to the number of hardware accelerator units, in which case the -// scheduler takes care of multiplexing the task types onto the shared hardware, -// in a manner that is both fair and efficient. -// -// Semantically, SharedBatchScheduler behaves like having N instances of -// BasicBatchScheduler (see basic_batch_scheduler.h), one per task type. The -// difference is that under the covers there is a single shared thread pool, -// instead of N independent ones, with their sharing deliberately coordinated. -// -// SharedBatchScheduler does not implement the BatchScheduler API; rather, it -// presents an abstraction of "queues", where each queue coresponds to one type -// of task. Tasks submitted to a given queue are placed in their own batches, -// and cannot be mixed with other tasks. Queues can be added and deleted -// dynamically, to accommodate e.g. versions of a model being brought up and -// down over the lifetime of a server. -// -// The batch thread pool round-robins through the queues, running one batch -// from a queue and then moving to the next queue. Each queue behaves like a -// BasicBatchScheduler instance, in the sense that it has maximum batch size and -// timeout parameters, which govern when a batch is eligible to be processed. -// -// Each queue is independently configured with a maximum size (in terms of the -// maximum number of batches worth of enqueued tasks). For online serving, it is -// recommended that the queue sizes be configured such that the sum of the sizes -// of the active queues roughly equal the number of batch threads. (The idea is -// that if all threads become available at roughly the same time, there will be -// enough enqueued work for them to take on, but no more.) -// -// If queue sizes are configured in the manner suggested above, the maximum time -// a task can spend in a queue before being placed in a batch and assigned to a -// thread for processing, is the greater of: -// - the maximum time to process one batch of tasks from any active queue -// - the configured timeout parameter for the task's queue (which can be 0) -// -// For bulk processing jobs and throughput-oriented benchmarks, you may want to -// set the maximum queue size to a large value. -// -// TODO(b/26539183): Support queue servicing policies other than round-robin. -// E.g. let each queue specify a "share" (an int >= 1), so e.g. with queues A -// and B having shares 1 and 2 respectively, the servicing pattern is ABBABB... -// -// -// PERFORMANCE TUNING: See README.md. -// -template -class SharedBatchScheduler - : public std::enable_shared_from_this> { - public: - // TODO(b/25089730): Tune defaults based on best practices as they develop. - struct Options { - // The name to use for the pool of batch threads. - string thread_pool_name = {"batch_threads"}; - - // The number of threads to use to process batches. - // Must be >= 1, and should be tuned carefully. - int num_batch_threads = port::NumSchedulableCPUs(); - - // The environment to use. - // (Typically only overridden by test code.) - Env* env = Env::Default(); - }; - // Ownership is shared between the caller of Create() and any queues created - // via AddQueue(). - static Status Create( - const Options& options, - std::shared_ptr>* scheduler); - - ~SharedBatchScheduler(); - - // Adds a queue to which tasks may be submitted. The returned queue implements - // the BatchScheduler API. Each queue has its own set of scheduling options, - // and its own callback to process batches of tasks submitted to the queue. - // - // The returned queue's destructor blocks until all tasks submitted to it have - // been processed. - struct QueueOptions { - // The maximum size of each batch. - // - // The scheduler may form batches of any size between 1 and this number - // (inclusive). If there is a need to quantize the batch sizes, i.e. only - // submit batches whose size is in a small set of allowed sizes, that can be - // done by adding padding in the process-batch callback. - int max_batch_size = 1000; - - // If a task has been enqueued for this amount of time (in microseconds), - // and a thread is available, the scheduler will immediately form a batch - // from enqueued tasks and assign the batch to the thread for processing, - // even if the batch's size is below 'max_batch_size'. - // - // This parameter offers a way to bound queue latency, so that a task isn't - // stuck in the queue indefinitely waiting for enough tasks to arrive to - // make a full batch. (The latency bound is given in the class documentation - // above.) - // - // The goal is to smooth out batch sizes under low request rates, and thus - // avoid latency spikes. - int64 batch_timeout_micros = 0; - - // The maximum allowable number of enqueued (accepted by Schedule() but - // not yet being processed on a batch thread) tasks in terms of batches. - // If this limit is reached, Schedule() will return an UNAVAILABLE error. - // See the class documentation above for guidelines on how to tune this - // parameter. - int max_enqueued_batches = 10; - }; - Status AddQueue(const QueueOptions& options, - std::function>)> - process_batch_callback, - std::unique_ptr>* queue); - - private: - explicit SharedBatchScheduler(const Options& options); - - // The code executed in 'batch_threads_'. Obtains a batch to process from the - // queue pointed to by 'next_queue_to_schedule_', and processes it. If that - // queue declines to provide a batch to process, moves onto the next queue. If - // no queues provide a batch to process, just sleeps briefly and exits. - void ThreadLogic(); - - const Options options_; - - mutex mu_; - - // A list of queues. (We use std::list instead of std::vector to ensure that - // iterators are not invalidated by adding/removing elements. It also offers - // efficient removal of elements from the middle.) - using QueueList = std::list>>; - - // All "active" queues, i.e. ones that either: - // - have not been removed, or - // - have been removed but are not yet empty. - QueueList queues_ GUARDED_BY(mu_); - - // An iterator over 'queues_', pointing to the queue from which the next - // available batch thread should grab work. - typename QueueList::iterator next_queue_to_schedule_ GUARDED_BY(mu_); - - // Used by idle batch threads to wait for work to enter the system. Notified - // whenever a batch becomes schedulable. - condition_variable schedulable_batch_cv_; - - // Threads that process batches obtained from the queues. - std::vector> batch_threads_; - - TF_DISALLOW_COPY_AND_ASSIGN(SharedBatchScheduler); -}; - -////////// -// Implementation details follow. API users need not read. - -namespace internal { - -// A task queue for SharedBatchScheduler. Accepts tasks and accumulates them -// into batches, and dispenses those batches to be processed via a "pull" -// interface. The queue's behavior is governed by maximum batch size, timeout -// and maximum queue length parameters; see their documentation in -// SharedBatchScheduler. -// -// The queue is implemented as a deque of batches, with these invariants: -// - The number of batches is between 1 and 'options_.max_enqueued_batches'. -// - The back-most batch is open; the rest are closed. -// -// Submitted tasks are added to the open batch. If that batch doesn't have room -// but the queue isn't full, then that batch is closed and a new open batch is -// started. -// -// Batch pull requests are handled by dequeuing the front-most batch if it is -// closed. If the front-most batch is open (i.e. the queue contains only one -// batch) and has reached the timeout, it is immediately closed and returned; -// otherwise no batch is returned for the request. -template -class Queue { - public: - using ProcessBatchCallback = - std::function>)>; - using SchedulableBatchCallback = std::function; - Queue(const typename SharedBatchScheduler::QueueOptions& options, - Env* env, ProcessBatchCallback process_batch_callback, - SchedulableBatchCallback schdulable_batch_callback); - - // Illegal to destruct unless the queue is empty. - ~Queue(); - - // Submits a task to the queue, with the same semantics as - // BatchScheduler::Schedule(). - Status Schedule(std::unique_ptr* task); - - // Returns the number of enqueued tasks, with the same semantics as - // BatchScheduler::NumEnqueuedTasks(). - size_t NumEnqueuedTasks() const; - - // Returns the queue capacity, with the same semantics as - // BatchScheduler::SchedulingCapacity(). - size_t SchedulingCapacity() const; - - // Called by a thread that is ready to process a batch, to request one from - // this queue. Either returns a batch that is ready to be processed, or - // nullptr if the queue declines to schedule a batch at this time. If it - // returns a batch, the batch is guaranteed to be closed. - std::unique_ptr> ScheduleBatch(); - - // Processes a batch that has been returned earlier by ScheduleBatch(). - void ProcessBatch(std::unique_ptr> batch); - - // Determines whether the queue is empty, i.e. has no tasks waiting or being - // processed. - bool IsEmpty() const; - - // Marks the queue closed, and waits until it is empty. - void CloseAndWaitUntilEmpty(); - - bool closed() const { - mutex_lock l(mu_); - return closed_; - } - - private: - // Same as IsEmpty(), but assumes the caller already holds a lock on 'mu_'. - bool IsEmptyInternal() const EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Closes the open batch residing at the back of 'batches_', and inserts a - // fresh open batch behind it. - void StartNewBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Determines whether the open batch residing at the back of 'batches_' is - // currently schedulable. - bool IsOpenBatchSchedulable() const EXCLUSIVE_LOCKS_REQUIRED(mu_); - - const typename SharedBatchScheduler::QueueOptions options_; - - // The environment to use. - Env* env_; - - // A callback invoked to processes a batch of work units. Always invoked from - // a batch thread. - ProcessBatchCallback process_batch_callback_; - - // A callback invoked to notify the scheduler that a new batch has become - // schedulable. - SchedulableBatchCallback schedulable_batch_callback_; - - mutable mutex mu_; - - // Whether this queue can accept new tasks. This variable is monotonic: it - // starts as false, and then at some point gets set to true and remains true - // for the duration of this object's life. - bool closed_ GUARDED_BY(mu_) = false; - - // The enqueued batches. See the invariants in the class comments above. - std::deque>> batches_ GUARDED_BY(mu_); - - // The time at which the first task was added to the open (back-most) batch - // in 'batches_'. Valid iff that batch contains at least one task. - uint64 open_batch_start_time_micros_ GUARDED_BY(mu_); - - // Whether this queue contains a batch that is eligible to be scheduled. Used - // to keep track of when to call 'schedulable_batch_callback_'. - bool schedulable_batch_ GUARDED_BY(mu_) = false; - - // The number of batches currently being processed by batch threads. - // Incremented in ScheduleBatch() and decremented in ProcessBatch(). - int num_batches_being_processed_ GUARDED_BY(mu_) = 0; - - // Used by CloseAndWaitUntilEmpty() to wait until the queue is empty, for the - // case in which the queue is not empty when CloseAndWaitUntilEmpty() starts. - // When ProcessBatch() dequeues the last batch and makes the queue empty, if - // 'empty_notification_' is non-null it calls 'empty_notification_->Notify()'. - Notification* empty_notification_ GUARDED_BY(mu_) = nullptr; - - TF_DISALLOW_COPY_AND_ASSIGN(Queue); -}; - -// A RAII-style object that points to a Queue and implements -// the BatchScheduler API. To be handed out to clients who call AddQueue(). -template -class QueueHandle : public BatchScheduler { - public: - QueueHandle(std::shared_ptr> scheduler, - Queue* queue); - ~QueueHandle() override; - - Status Schedule(std::unique_ptr* task) override; - size_t NumEnqueuedTasks() const override; - size_t SchedulingCapacity() const override; - - private: - // The scheduler that owns 'queue_'. - std::shared_ptr> scheduler_; - - // The queue this handle wraps. Owned by 'scheduler_', which keeps it alive at - // least until this class's destructor closes it. - Queue* queue_; - - TF_DISALLOW_COPY_AND_ASSIGN(QueueHandle); -}; - -} // namespace internal - -template -Status SharedBatchScheduler::Create( - const Options& options, - std::shared_ptr>* scheduler) { - if (options.num_batch_threads < 1) { - return errors::InvalidArgument("num_batch_threads must be positive; was ", - options.num_batch_threads); - } - scheduler->reset(new SharedBatchScheduler(options)); - return Status::OK(); -} - -template -SharedBatchScheduler::~SharedBatchScheduler() { - // Wait until the batch threads finish clearing out and deleting the closed - // queues. - for (;;) { - { - mutex_lock l(mu_); - if (queues_.empty()) { - break; - } - } - const int64 kSleepTimeMicros = 100; - options_.env->SleepForMicroseconds(kSleepTimeMicros); - } - // Delete the batch threads before allowing state the threads may access (e.g. - // 'mu_') to be deleted. - batch_threads_.clear(); -} - -template -Status SharedBatchScheduler::AddQueue( - const QueueOptions& options, - std::function>)> - process_batch_callback, - std::unique_ptr>* queue) { - if (options.max_batch_size <= 0) { - return errors::InvalidArgument("max_batch_size must be positive; was ", - options.max_batch_size); - } - if (options.batch_timeout_micros < 0) { - return errors::InvalidArgument( - "batch_timeout_micros must be non-negative; was ", - options.batch_timeout_micros); - } - if (options.max_enqueued_batches < 0) { - return errors::InvalidArgument( - "max_enqueued_batches must be non-negative; was ", - options.max_enqueued_batches); - } - - auto schedulable_batch_callback = [this] { - mutex_lock l(mu_); - schedulable_batch_cv_.notify_one(); - }; - auto internal_queue = - std::unique_ptr>(new internal::Queue( - options, options_.env, process_batch_callback, - schedulable_batch_callback)); - auto handle = std::unique_ptr>( - new internal::QueueHandle(this->shared_from_this(), - internal_queue.get())); - { - mutex_lock l(mu_); - queues_.push_back(std::move(internal_queue)); - if (next_queue_to_schedule_ == queues_.end()) { - next_queue_to_schedule_ = queues_.begin(); - } - } - *queue = std::move(handle); - return Status::OK(); -} - -template -SharedBatchScheduler::SharedBatchScheduler(const Options& options) - : options_(options), next_queue_to_schedule_(queues_.end()) { - // Kick off the batch threads. - PeriodicFunction::Options periodic_fn_options; - periodic_fn_options.thread_name_prefix = - strings::StrCat(options.thread_pool_name, "_"); - for (int i = 0; i < options.num_batch_threads; ++i) { - std::unique_ptr thread(new PeriodicFunction( - [this] { this->ThreadLogic(); }, - 0 /* function invocation interval time */, periodic_fn_options)); - batch_threads_.push_back(std::move(thread)); - } -} - -template -void SharedBatchScheduler::ThreadLogic() { - // A batch to process next (or nullptr if no work to do). - std::unique_ptr> batch_to_process; - // The queue with which 'batch_to_process' is associated. - internal::Queue* queue_for_batch = nullptr; - { - mutex_lock l(mu_); - - const int num_queues = queues_.size(); - for (int num_queues_tried = 0; - batch_to_process == nullptr && num_queues_tried < num_queues; - ++num_queues_tried) { - DCHECK(next_queue_to_schedule_ != queues_.end()); - - // If a closed queue responds to ScheduleBatch() with nullptr, the queue - // will never yield any further batches so we can drop it. To avoid a - // race, we take a snapshot of the queue's closedness state *before* - // calling ScheduleBatch(). - const bool queue_closed = (*next_queue_to_schedule_)->closed(); - - // Ask '*next_queue_to_schedule_' if it wants us to process a batch. - batch_to_process = (*next_queue_to_schedule_)->ScheduleBatch(); - if (batch_to_process != nullptr) { - queue_for_batch = next_queue_to_schedule_->get(); - } - - // Advance 'next_queue_to_schedule_'. - if (queue_closed && (*next_queue_to_schedule_)->IsEmpty() && - batch_to_process == nullptr) { - // We've encountered a closed queue with no work to do. Drop it. - DCHECK_NE(queue_for_batch, next_queue_to_schedule_->get()); - next_queue_to_schedule_ = queues_.erase(next_queue_to_schedule_); - } else { - ++next_queue_to_schedule_; - } - if (next_queue_to_schedule_ == queues_.end() && !queues_.empty()) { - // We've hit the end. Wrap to the first queue. - next_queue_to_schedule_ = queues_.begin(); - } - } - - if (batch_to_process == nullptr) { - // We couldn't find any work to do. Wait until a new batch becomes - // schedulable, or some time has elapsed, before checking again. - const int64 kTimeoutMillis = 1; // The smallest accepted granule of time. - WaitForMilliseconds(&l, &schedulable_batch_cv_, kTimeoutMillis); - return; - } - } - - queue_for_batch->ProcessBatch(std::move(batch_to_process)); -} - -namespace internal { - -template -Queue::Queue( - const typename SharedBatchScheduler::QueueOptions& options, - Env* env, ProcessBatchCallback process_batch_callback, - SchedulableBatchCallback schedulable_batch_callback) - : options_(options), - env_(env), - process_batch_callback_(process_batch_callback), - schedulable_batch_callback_(schedulable_batch_callback) { - // Create an initial, open batch. - batches_.emplace_back(new Batch); -} - -template -Queue::~Queue() { - mutex_lock l(mu_); - DCHECK(IsEmptyInternal()); - - // Close the (empty) open batch, so its destructor doesn't block. - batches_.back()->Close(); -} - -template -Status Queue::Schedule(std::unique_ptr* task) { - if ((*task)->size() > options_.max_batch_size) { - return errors::InvalidArgument("Task size ", (*task)->size(), - " is larger than maximum batch size ", - options_.max_batch_size); - } - - bool notify_of_schedulable_batch = false; - { - mutex_lock l(mu_); - - DCHECK(!closed_); - - if (batches_.back()->size() + (*task)->size() > options_.max_batch_size) { - if (batches_.size() >= options_.max_enqueued_batches) { - return errors::Unavailable( - "The batch scheduling queue to which this task was submitted is " - "full"); - } - StartNewBatch(); - } - if (batches_.back()->empty()) { - open_batch_start_time_micros_ = env_->NowMicros(); - } - batches_.back()->AddTask(std::move(*task)); - - if (!schedulable_batch_) { - if (batches_.size() > 1 || IsOpenBatchSchedulable()) { - schedulable_batch_ = true; - notify_of_schedulable_batch = true; - } - } - } - - if (notify_of_schedulable_batch) { - schedulable_batch_callback_(); - } - - return Status::OK(); -} - -template -size_t Queue::NumEnqueuedTasks() const { - mutex_lock l(mu_); - size_t num_enqueued_tasks = 0; - for (const auto& batch : batches_) { - num_enqueued_tasks += batch->num_tasks(); - } - return num_enqueued_tasks; -} - -template -size_t Queue::SchedulingCapacity() const { - mutex_lock l(mu_); - const int num_new_batches_schedulable = - options_.max_enqueued_batches - batches_.size(); - const int open_batch_capacity = - options_.max_batch_size - batches_.back()->size(); - return (num_new_batches_schedulable * options_.max_batch_size) + - open_batch_capacity; -} - -template -std::unique_ptr> Queue::ScheduleBatch() { - // The batch to schedule, which we may populate below. (If left as nullptr, - // that means we are electing not to schedule a batch at this time.) - std::unique_ptr> batch_to_schedule; - - { - mutex_lock l(mu_); - - // Consider closing the open batch at this time, to schedule it. - if (batches_.size() == 1 && IsOpenBatchSchedulable()) { - StartNewBatch(); - } - - if (batches_.size() >= 2) { - // There is at least one closed batch that is ready to be scheduled. - ++num_batches_being_processed_; - batch_to_schedule = std::move(batches_.front()); - batches_.pop_front(); - } else { - schedulable_batch_ = false; - } - } - - return batch_to_schedule; -} - -template -void Queue::ProcessBatch(std::unique_ptr> batch) { - process_batch_callback_(std::move(batch)); - - { - mutex_lock l(mu_); - --num_batches_being_processed_; - if (empty_notification_ != nullptr && IsEmptyInternal()) { - empty_notification_->Notify(); - } - } -} - -template -bool Queue::IsEmpty() const { - mutex_lock l(mu_); - return IsEmptyInternal(); -} - -template -void Queue::CloseAndWaitUntilEmpty() { - Notification empty; - { - mutex_lock l(mu_); - closed_ = true; - if (IsEmptyInternal()) { - empty.Notify(); - } else { - // Arrange for ProcessBatch() to notify when the queue becomes empty. - empty_notification_ = ∅ - } - } - empty.WaitForNotification(); -} - -template -bool Queue::IsEmptyInternal() const { - return num_batches_being_processed_ == 0 && batches_.size() == 1 && - batches_.back()->empty(); -} - -template -void Queue::StartNewBatch() { - batches_.back()->Close(); - batches_.emplace_back(new Batch); -} - -template -bool Queue::IsOpenBatchSchedulable() const { - Batch* open_batch = batches_.back().get(); - if (open_batch->empty()) { - return false; - } - return closed_ || open_batch->size() >= options_.max_batch_size || - env_->NowMicros() >= - open_batch_start_time_micros_ + options_.batch_timeout_micros; -} - -template -QueueHandle::QueueHandle( - std::shared_ptr> scheduler, - Queue* queue) - : scheduler_(scheduler), queue_(queue) {} - -template -QueueHandle::~QueueHandle() { - queue_->CloseAndWaitUntilEmpty(); -} - -template -Status QueueHandle::Schedule(std::unique_ptr* task) { - return queue_->Schedule(task); -} - -template -size_t QueueHandle::NumEnqueuedTasks() const { - return queue_->NumEnqueuedTasks(); -} - -template -size_t QueueHandle::SchedulingCapacity() const { - return queue_->SchedulingCapacity(); -} - -} // namespace internal - -} // namespace serving -} // namespace tensorflow +#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD index d1ced0d8c367f44b520a9bba2db8a3e0969bab4c..6db627faad1df4a4b73082e74e7754829ff2b514 100644 --- a/tensorflow/contrib/batching/test_util/BUILD +++ b/tensorflow/contrib/batching/test_util/BUILD @@ -22,11 +22,9 @@ filegroup( cc_library( name = "fake_clock_env", testonly = 1, - srcs = ["fake_clock_env.cc"], hdrs = ["fake_clock_env.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:tensorflow", + "//tensorflow/core/kernels/batching_util:fake_clock_env", ], ) diff --git a/tensorflow/contrib/batching/test_util/fake_clock_env.h b/tensorflow/contrib/batching/test_util/fake_clock_env.h index 35cafcb73c51feb4e9e15a61d1830c8ef6bc3e0f..ced27a88336324fb8c4be490138291d9234693f9 100644 --- a/tensorflow/contrib/batching/test_util/fake_clock_env.h +++ b/tensorflow/contrib/batching/test_util/fake_clock_env.h @@ -16,61 +16,6 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ -#include -#include -#include - -#include "tensorflow/core/lib/core/notification.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { -namespace test_util { - -// An Env implementation with a fake clock for NowMicros() and -// SleepForMicroseconds(). The clock doesn't advance on its own; it advances via -// an explicit Advance() method. -// All other Env virtual methods pass through to a wrapped Env. -class FakeClockEnv : public EnvWrapper { - public: - explicit FakeClockEnv(Env* wrapped); - ~FakeClockEnv() override = default; - - // Advance the clock by a certain number of microseconds. - void AdvanceByMicroseconds(int micros); - - // Blocks until there is a sleeping thread that is scheduled to wake up at - // the given (absolute) time. - void BlockUntilSleepingThread(uint64 wake_time); - - // Blocks until there are at least num_threads sleeping. - void BlockUntilThreadsAsleep(int num_threads); - - // Methods that this class implements. - uint64 NowMicros() override; - void SleepForMicroseconds(int64 micros) override; - - private: - mutex mu_; - - uint64 current_time_ GUARDED_BY(mu_) = 0; - - struct SleepingThread { - uint64 wake_time; - Notification* wake_notification; - }; - std::vector sleeping_threads_ GUARDED_BY(mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(FakeClockEnv); -}; - -} // namespace test_util -} // namespace serving -} // namespace tensorflow +#include "tensorflow/core/kernels/batching_util/fake_clock_env.h" #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD index f33a08cb817e9f2832be953ef6ff1aba04c4c288..2a84a7712a8fa66e89db41ff4e7ebe4f620029ca 100644 --- a/tensorflow/contrib/batching/util/BUILD +++ b/tensorflow/contrib/batching/util/BUILD @@ -22,12 +22,11 @@ filegroup( cc_library( name = "periodic_function_dynamic", - srcs = ["periodic_function.cc"], hdrs = ["periodic_function.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels/batching_util:periodic_function_dynamic", + "//third_party/eigen3", ], ) @@ -36,17 +35,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":periodic_function_dynamic", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "periodic_function_test", - srcs = ["periodic_function_test.cc"], - deps = [ - ":periodic_function_dynamic", - "//tensorflow/contrib/batching/test_util:fake_clock_env", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/core/kernels/batching_util:periodic_function", ], ) diff --git a/tensorflow/contrib/batching/util/periodic_function.h b/tensorflow/contrib/batching/util/periodic_function.h index 2c032d802fe5f23a267db28dc869a253f16afc34..fb61bc2eea2ec6eb560670148611c66ddc3d73df 100644 --- a/tensorflow/contrib/batching/util/periodic_function.h +++ b/tensorflow/contrib/batching/util/periodic_function.h @@ -12,121 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - -// PeriodicFunction will periodically call the given function with a specified -// period in a background thread. After Start() returns, the thread is -// guaranteed to have started. The destruction of the class causes the -// background thread to be destroyed as well. Start() should not be called more -// than once. -// -// PeriodicFunction runs the function as soon as any previous run both is -// complete and was started more than "interval_micros" earlier. Thus, runs are -// both serialized, and normally have a period of "interval_micros" if no run -// exceeds the time. -// -// Note that, if the function takes longer than two interval_micross to finish, -// then PeriodicFunction will "skip" at least one call to the function. For -// instance, if the period is 50ms and the function starts runs at time 0 for -// 150ms, then the function will immediately start executing again at time 150, -// but there will be no function runs corresponding to times 50 or 100. This is -// especially important to remember when using an environment with a simulated -// clock: advancing simulated time atomically over N interval_micross will not -// cause the function to be called N times. -// -// This object is thread-safe. -// -// Example: -// -// class Foo { -// public: -// Foo() : periodic_function_([this]() { Bar(); }, -// 1000 /* 1000us == 1ms*/) { -// } -// -// private: -// void Bar() { ... } -// -// PeriodicFunction periodic_function_; -// }; - #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ -#include -#include -#include - -#include "tensorflow/core/lib/core/notification.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { - -namespace internal { -class PeriodicFunctionTestAccess; -} - -class PeriodicFunction { - public: - // Provides the ability to customize several aspects of the PeriodicFunction. - // Passed to constructor of PeriodicFunction. - struct Options { - Options() {} - - // Any standard thread options, such as stack size, should - // be passed via "thread_options". - ThreadOptions thread_options; - - // Specifies the thread name prefix (see the description in class - // Thread). - string thread_name_prefix = "periodic_function"; - - // The environment to use. Does not take ownership, but must remain alive - // for as long as the PeriodicFunction exists. - Env* env = Env::Default(); - - // Specifies the length of sleep before the first invocation of the - // function. - // This can be used for adding a random jitter to avoid synchronous behavior - // across multiple periodic functions. - int64 startup_delay_micros = 0; - }; - - // Also starts the background thread which will be calling the function. - PeriodicFunction(const std::function& function, int64 interval_micros, - const Options& options = Options()); - - ~PeriodicFunction(); - - private: - friend class internal::PeriodicFunctionTestAccess; - - // Notifies the background thread to stop. - void NotifyStop(); - - // (Blocking.) Loops forever calling "function_" every "interval_micros_". - void RunLoop(int64 start) LOCKS_EXCLUDED(mutex_); - - const std::function function_; // Actual client function - const int64 interval_micros_; // Interval between calls. - const Options options_; - - // Protects state below. - mutable mutex mutex_; - // Used to notify the thread to stop. - Notification stop_thread_; - - // Thread for running "function_" - std::unique_ptr thread_ GUARDED_BY(mutex_) = nullptr; - - TF_DISALLOW_COPY_AND_ASSIGN(PeriodicFunction); -}; - -} // namespace serving -} // namespace tensorflow +#include "tensorflow/core/kernels/batching_util/periodic_function.h" #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index a262d4aecdbb69dfcd8b88bc0a09060500d6b1c9..11c3c037c4e8b4ba41eae60d28d6aac49f1488f2 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -99,6 +99,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "layers_conv_variational_test", + size = "small", + srcs = ["python/kernel_tests/layers_conv_variational_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + ], +) + cuda_py_test( name = "layers_dense_variational_test", size = "small", @@ -200,6 +219,28 @@ cuda_py_test( ], ) +cuda_py_test( + name = "variational_sgd_optimizer_test", + size = "small", + srcs = ["python/kernel_tests/variational_sgd_optimizer_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", + ], + tags = ["notsan"], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index b1f108e5f01e4945ee83d8262f1d99877f0fe9f0..cbc66b6dc13db62c25952de6b6c13b2fdfe27f12 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Hamiltonian Monte Carlo. -""" +"""Tests for Hamiltonian Monte Carlo.""" from __future__ import absolute_import from __future__ import division @@ -27,6 +26,7 @@ from tensorflow.contrib.bayesflow.python.ops import hmc from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -46,6 +46,9 @@ class HMCTest(test.TestCase): random_seed.set_random_seed(10003) np.random.seed(10003) + def assertAllFinite(self, x): + self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) + def _log_gamma_log_prob(self, x, event_dims=()): """Computes log-pdf of a log-gamma random variable. @@ -345,5 +348,97 @@ class HMCTest(test.TestCase): def testAIS12(self): self._ais_gets_correct_log_normalizer_wrapper([1, 2]) + def testNanRejection(self): + """Tests that an update that yields NaN potentials gets rejected. + + We run HMC with a target distribution that returns NaN + log-likelihoods if any element of x < 0, and unit-scale + exponential log-likelihoods otherwise. The exponential potential + pushes x towards 0, ensuring that any reasonably large update will + push us over the edge into NaN territory. + """ + def _unbounded_exponential_log_prob(x): + """An exponential distribution with log-likelihood NaN for x < 0.""" + per_element_potentials = array_ops.where(x < 0, + np.nan * array_ops.ones_like(x), + -x) + return math_ops.reduce_sum(per_element_potentials) + + with self.test_session() as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, acceptance_probs, _, _ = hmc.kernel( + 2., 5, initial_x, _unbounded_exponential_log_prob, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) + + def testNanFromGradsDontPropagate(self): + """Test that update with NaN gradients does not cause NaN in results.""" + def _nan_log_prob_with_nan_gradient(x): + return np.nan * math_ops.reduce_sum(x) + + with self.test_session() as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel( + 2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) + + self.assertAllFinite( + gradients_impl.gradients(updated_x, initial_x)[0].eval()) + self.assertTrue( + gradients_impl.gradients(new_grad, initial_x)[0] is None) + + # Gradients of the acceptance probs and new log prob are not finite. + _ = new_log_prob # Prevent unused arg error. + # self.assertAllFinite( + # gradients_impl.gradients(acceptance_probs, initial_x)[0].eval()) + # self.assertAllFinite( + # gradients_impl.gradients(new_log_prob, initial_x)[0].eval()) + + def testChainWorksIn64Bit(self): + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float64(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float64), + target_log_prob_fn=log_prob, + event_dims=[-1]) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float64, states_.dtype) + self.assertEqual(np.float64, acceptance_probs_.dtype) + + def testChainWorksIn16Bit(self): + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float16(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float16), + target_log_prob_fn=log_prob, + event_dims=[-1]) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float16, states_.dtype) + self.assertEqual(np.float16, acceptance_probs_.dtype) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py new file mode 100644 index 0000000000000000000000000000000000000000..750afb6654311fea30a1dc6b31b20aa3b4160ae2 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py @@ -0,0 +1,521 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for convolutional Bayesian layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib +from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.platform import test + + +class Counter(object): + """Helper class to manage incrementing a counting `int`.""" + + def __init__(self): + self._value = -1 + + @property + def value(self): + return self._value + + def __call__(self): + self._value += 1 + return self._value + + +class MockDistribution(independent_lib.Independent): + """Monitors layer calls to the underlying distribution.""" + + def __init__(self, result_sample, result_log_prob, loc=None, scale=None): + self.result_sample = result_sample + self.result_log_prob = result_log_prob + self.result_loc = loc + self.result_scale = scale + self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) + if loc is not None and scale is not None: + self.result_distribution = normal_lib.Normal(loc=self.result_loc, + scale=self.result_scale) + self.called_log_prob = Counter() + self.called_sample = Counter() + self.called_loc = Counter() + self.called_scale = Counter() + + def log_prob(self, *args, **kwargs): + self.called_log_prob() + return self.result_log_prob + + def sample(self, *args, **kwargs): + self.called_sample() + return self.result_sample + + @property + def distribution(self): # for dummy check on Independent(Normal) + return self.result_distribution + + @property + def loc(self): + self.called_loc() + return self.result_loc + + @property + def scale(self): + self.called_scale() + return self.result_scale + + +class MockKLDivergence(object): + """Monitors layer calls to the divergence implementation.""" + + def __init__(self, result): + self.result = result + self.args = [] + self.called = Counter() + + def __call__(self, *args, **kwargs): + self.called() + self.args.append(args) + return self.result + + +class ConvVariational(test.TestCase): + + def _testKLPenaltyKernel(self, layer_class): + with self.test_session(): + layer = layer_class(filters=2, kernel_size=3) + if layer_class in (prob_layers_lib.Conv1DReparameterization, + prob_layers_lib.Conv1DFlipout): + inputs = random_ops.random_uniform([2, 3, 1], seed=1) + elif layer_class in (prob_layers_lib.Conv2DReparameterization, + prob_layers_lib.Conv2DFlipout): + inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) + elif layer_class in (prob_layers_lib.Conv3DReparameterization, + prob_layers_lib.Conv3DFlipout): + inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) + + # No keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) + + _ = layer(inputs) + + # Yes keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 1) + self.assertListEqual(layer.losses, losses) + + def _testKLPenaltyBoth(self, layer_class): + def _make_normal(dtype, *args): # pylint: disable=unused-argument + return normal_lib.Normal( + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) + with self.test_session(): + layer = layer_class( + filters=2, + kernel_size=3, + bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), + bias_prior_fn=_make_normal) + if layer_class in (prob_layers_lib.Conv1DReparameterization, + prob_layers_lib.Conv1DFlipout): + inputs = random_ops.random_uniform([2, 3, 1], seed=1) + elif layer_class in (prob_layers_lib.Conv2DReparameterization, + prob_layers_lib.Conv2DFlipout): + inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) + elif layer_class in (prob_layers_lib.Conv3DReparameterization, + prob_layers_lib.Conv3DFlipout): + inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) + + # No keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) + + _ = layer(inputs) + + # Yes keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 2) + self.assertListEqual(layer.losses, losses) + + def _testConvSetUp(self, layer_class, batch_size, depth=None, + height=None, width=None, channels=None, filters=None, + **kwargs): + seed = Counter() + if layer_class in (prob_layers_lib.Conv1DReparameterization, + prob_layers_lib.Conv1DFlipout): + inputs = random_ops.random_uniform( + [batch_size, width, channels], seed=seed()) + kernel_size = (2,) + elif layer_class in (prob_layers_lib.Conv2DReparameterization, + prob_layers_lib.Conv2DFlipout): + inputs = random_ops.random_uniform( + [batch_size, height, width, channels], seed=seed()) + kernel_size = (2, 2) + elif layer_class in (prob_layers_lib.Conv3DReparameterization, + prob_layers_lib.Conv3DFlipout): + inputs = random_ops.random_uniform( + [batch_size, depth, height, width, channels], seed=seed()) + kernel_size = (2, 2, 2) + + kernel_shape = kernel_size + (channels, filters) + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform(kernel_shape, seed=seed()), + scale=random_ops.random_uniform(kernel_shape, seed=seed()), + result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), + result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), + result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_shape, seed=seed())) + + bias_size = (filters,) + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + layer = layer_class( + filters=filters, + kernel_size=kernel_size, + padding="SAME", + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence, + **kwargs) + + outputs = layer(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + return (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, + layer, inputs, outputs, kl_penalty, kernel_shape) + + def _testConvReparameterization(self, layer_class): + batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 + with self.test_session() as sess: + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty, kernel_shape) = self._testConvSetUp( + layer_class, batch_size, + depth=depth, height=height, width=width, channels=channels, + filters=filters) + + convolution_op = nn_ops.Convolution( + tensor_shape.TensorShape(inputs.shape), + filter_shape=tensor_shape.TensorShape(kernel_shape), + padding="SAME") + expected_outputs = convolution_op(inputs, kernel_posterior.result_sample) + expected_outputs = nn.bias_add(expected_outputs, + bias_posterior.result_sample, + data_format="NHWC") + + [ + expected_outputs_, actual_outputs_, + expected_kernel_, actual_kernel_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_posterior.result_sample, layer.kernel_posterior_tensor, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, layer.bias_posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_kernel_, actual_kernel_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior.distribution, + kernel_prior.distribution, + kernel_posterior.result_sample]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], + bias_divergence.args) + + def _testConvFlipout(self, layer_class): + batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 + with self.test_session() as sess: + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty, kernel_shape) = self._testConvSetUp( + layer_class, batch_size, + depth=depth, height=height, width=width, channels=channels, + filters=filters, seed=44) + + convolution_op = nn_ops.Convolution( + tensor_shape.TensorShape(inputs.shape), + filter_shape=tensor_shape.TensorShape(kernel_shape), + padding="SAME") + + expected_kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(kernel_posterior.result_loc), + scale=kernel_posterior.result_scale) + expected_kernel_posterior_affine_tensor = ( + expected_kernel_posterior_affine.sample(seed=42)) + + expected_outputs = convolution_op( + inputs, kernel_posterior.distribution.loc) + + input_shape = array_ops.shape(inputs) + output_shape = array_ops.shape(expected_outputs) + batch_shape = array_ops.expand_dims(input_shape[0], 0) + channels = input_shape[-1] + rank = len(inputs.get_shape()) - 2 + + sign_input = random_ops.random_uniform( + array_ops.concat([batch_shape, + array_ops.expand_dims(channels, 0)], 0), + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=layer.seed) + sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) + sign_output = random_ops.random_uniform( + array_ops.concat([batch_shape, + array_ops.expand_dims(filters, 0)], 0), + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=distribution_util.gen_new_seed( + layer.seed, salt="conv_flipout")) + sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) + for _ in range(rank): + sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) + sign_output = array_ops.expand_dims(sign_output, 1) + + sign_input = array_ops.tile( # tile for element-wise op broadcasting + sign_input, + [1] + [input_shape[i + 1] for i in range(rank)] + [1]) + sign_output = array_ops.tile( + sign_output, + [1] + [output_shape[i + 1] for i in range(rank)] + [1]) + + perturbed_inputs = convolution_op( + inputs * sign_input, expected_kernel_posterior_affine_tensor) + perturbed_inputs *= sign_output + + expected_outputs += perturbed_inputs + expected_outputs = nn.bias_add(expected_outputs, + bias_posterior.result_sample, + data_format="NHWC") + + [ + expected_outputs_, actual_outputs_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, layer.bias_posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior.distribution, kernel_prior.distribution, None]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], + bias_divergence.args) + + def _testRandomConvFlipout(self, layer_class): + batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 + with self.test_session() as sess: + seed = Counter() + if layer_class in (prob_layers_lib.Conv1DReparameterization, + prob_layers_lib.Conv1DFlipout): + inputs = random_ops.random_uniform( + [batch_size, width, channels], seed=seed()) + kernel_size = (2,) + elif layer_class in (prob_layers_lib.Conv2DReparameterization, + prob_layers_lib.Conv2DFlipout): + inputs = random_ops.random_uniform( + [batch_size, height, width, channels], seed=seed()) + kernel_size = (2, 2) + elif layer_class in (prob_layers_lib.Conv3DReparameterization, + prob_layers_lib.Conv3DFlipout): + inputs = random_ops.random_uniform( + [batch_size, depth, height, width, channels], seed=seed()) + kernel_size = (2, 2, 2) + + kernel_shape = kernel_size + (channels, filters) + bias_size = (filters,) + + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform( + kernel_shape, seed=seed()), + scale=random_ops.random_uniform( + kernel_shape, seed=seed()), + result_log_prob=random_ops.random_uniform( + kernel_shape, seed=seed()), + result_sample=random_ops.random_uniform( + kernel_shape, seed=seed())) + bias_posterior = MockDistribution( + loc=random_ops.random_uniform( + bias_size, seed=seed()), + scale=random_ops.random_uniform( + bias_size, seed=seed()), + result_log_prob=random_ops.random_uniform( + bias_size, seed=seed()), + result_sample=random_ops.random_uniform( + bias_size, seed=seed())) + layer_one = layer_class( + filters=filters, + kernel_size=kernel_size, + padding="SAME", + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=44) + layer_two = layer_class( + filters=filters, + kernel_size=kernel_size, + padding="SAME", + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=45) + + outputs_one = layer_one(inputs) + outputs_two = layer_two(inputs) + + outputs_one_, outputs_two_ = sess.run([ + outputs_one, outputs_two]) + + self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), + np.prod(outputs_one_.shape)) + + def testKLPenaltyKernelConv1DReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv1DReparameterization) + + def testKLPenaltyKernelConv2DReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv2DReparameterization) + + def testKLPenaltyKernelConv3DReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv3DReparameterization) + + def testKLPenaltyKernelConv1DFlipout(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv1DFlipout) + + def testKLPenaltyKernelConv2DFlipout(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv2DFlipout) + + def testKLPenaltyKernelConv3DFlipout(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv3DFlipout) + + def testKLPenaltyBothConv1DReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv1DReparameterization) + + def testKLPenaltyBothConv2DReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv2DReparameterization) + + def testKLPenaltyBothConv3DReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv3DReparameterization) + + def testKLPenaltyBothConv1DFlipout(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv1DFlipout) + + def testKLPenaltyBothConv2DFlipout(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv2DFlipout) + + def testKLPenaltyBothConv3DFlipout(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv3DFlipout) + + def testConv1DReparameterization(self): + self._testConvReparameterization(prob_layers_lib.Conv1DReparameterization) + + def testConv2DReparameterization(self): + self._testConvReparameterization(prob_layers_lib.Conv2DReparameterization) + + def testConv3DReparameterization(self): + self._testConvReparameterization(prob_layers_lib.Conv3DReparameterization) + + def testConv1DFlipout(self): + self._testConvFlipout(prob_layers_lib.Conv1DFlipout) + + def testConv2DFlipout(self): + self._testConvFlipout(prob_layers_lib.Conv2DFlipout) + + def testConv3DFlipout(self): + self._testConvFlipout(prob_layers_lib.Conv3DFlipout) + + def testRandomConv1DFlipout(self): + self._testRandomConvFlipout(prob_layers_lib.Conv1DFlipout) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py index 50358fd1c2b7635ffe2d08c5af3219bb0a11498b..342f38ccec7ec74db1b393d6cdc22300205cc547 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py @@ -18,11 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational as prob_layers_lib +from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.platform import test @@ -41,14 +48,18 @@ class Counter(object): return self._value -class MockDistribution(normal_lib.Normal): - """Monitors DenseVariational calls to the underlying distribution.""" +class MockDistribution(independent_lib.Independent): + """Monitors layer calls to the underlying distribution.""" def __init__(self, result_sample, result_log_prob, loc=None, scale=None): self.result_sample = result_sample self.result_log_prob = result_log_prob self.result_loc = loc self.result_scale = scale + self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) + if loc is not None and scale is not None: + self.result_distribution = normal_lib.Normal(loc=self.result_loc, + scale=self.result_scale) self.called_log_prob = Counter() self.called_sample = Counter() self.called_loc = Counter() @@ -62,6 +73,10 @@ class MockDistribution(normal_lib.Normal): self.called_sample() return self.result_sample + @property + def distribution(self): # for dummy check on Independent(Normal) + return self.result_distribution + @property def loc(self): self.called_loc() @@ -74,7 +89,7 @@ class MockDistribution(normal_lib.Normal): class MockKLDivergence(object): - """Monitors DenseVariational calls to the divergence implementation.""" + """Monitors layer calls to the divergence implementation.""" def __init__(self, result): self.result = result @@ -87,94 +102,125 @@ class MockKLDivergence(object): return self.result -class DenseVariationalLocalReparametrization(test.TestCase): +class DenseVariational(test.TestCase): - def testKLPenaltyKernel(self): + def _testKLPenaltyKernel(self, layer_class): with self.test_session(): - dense_vi = prob_layers_lib.DenseVariational(units=2) + layer = layer_class(units=2) inputs = random_ops.random_uniform([2, 3], seed=1) # No keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 0) - self.assertListEqual(dense_vi.losses, loss_keys) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) - _ = dense_vi(inputs) + _ = layer(inputs) # Yes keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 1) - self.assertListEqual(dense_vi.losses, loss_keys) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 1) + self.assertListEqual(layer.losses, losses) - def testKLPenaltyBoth(self): + def _testKLPenaltyBoth(self, layer_class): def _make_normal(dtype, *args): # pylint: disable=unused-argument return normal_lib.Normal( loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) with self.test_session(): - dense_vi = prob_layers_lib.DenseVariational( + layer = layer_class( units=2, - bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(), + bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), bias_prior_fn=_make_normal) inputs = random_ops.random_uniform([2, 3], seed=1) # No keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 0) - self.assertListEqual(dense_vi.losses, loss_keys) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) - _ = dense_vi(inputs) + _ = layer(inputs) # Yes keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 2) - self.assertListEqual(dense_vi.losses, loss_keys) - - def testVariationalNonLocal(self): + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 2) + self.assertListEqual(layer.losses, losses) + + def _testDenseSetUp(self, layer_class, batch_size, in_size, out_size, + **kwargs): + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_size = [in_size, out_size] + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform(kernel_size, seed=seed()), + scale=random_ops.random_uniform(kernel_size, seed=seed()), + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_size, seed=seed())) + + bias_size = [out_size] + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + layer = layer_class( + units=out_size, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence, + **kwargs) + + outputs = layer(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + return (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, + layer, inputs, outputs, kl_penalty) + + def testKLPenaltyKernelReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.DenseReparameterization) + + def testKLPenaltyKernelLocalReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.DenseLocalReparameterization) + + def testKLPenaltyKernelFlipout(self): + self._testKLPenaltyKernel(prob_layers_lib.DenseFlipout) + + def testKLPenaltyBothReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.DenseReparameterization) + + def testKLPenaltyBothLocalReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.DenseLocalReparameterization) + + def testKLPenaltyBothFlipout(self): + self._testKLPenaltyBoth(prob_layers_lib.DenseFlipout) + + def testDenseReparameterization(self): batch_size, in_size, out_size = 2, 3, 4 with self.test_session() as sess: - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_size = [in_size, out_size] - kernel_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_size, seed=seed())) - - bias_size = [out_size] - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty) = self._testDenseSetUp( + prob_layers_lib.DenseReparameterization, + batch_size, in_size, out_size) expected_outputs = ( math_ops.matmul(inputs, kernel_posterior.result_sample) + bias_posterior.result_sample) - dense_vi = prob_layers_lib.DenseVariational( - units=2, - kernel_use_local_reparameterization=False, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence) - - outputs = dense_vi(inputs) - - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - [ expected_outputs_, actual_outputs_, expected_kernel_, actual_kernel_, @@ -183,9 +229,9 @@ class DenseVariationalLocalReparametrization(test.TestCase): expected_bias_divergence_, actual_bias_divergence_, ] = sess.run([ expected_outputs, outputs, - kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor, + kernel_posterior.result_sample, layer.kernel_posterior_tensor, kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_posterior.result_sample, layer.bias_posterior_tensor, bias_divergence.result, kl_penalty[1], ]) @@ -206,40 +252,25 @@ class DenseVariationalLocalReparametrization(test.TestCase): rtol=1e-6, atol=0.) self.assertAllEqual( - [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]], + [[kernel_posterior.distribution, + kernel_prior.distribution, + kernel_posterior.result_sample]], kernel_divergence.args) self.assertAllEqual( - [[bias_posterior, bias_prior, bias_posterior.result_sample]], + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], bias_divergence.args) - def testVariationalLocal(self): + def testDenseLocalReparameterization(self): batch_size, in_size, out_size = 2, 3, 4 with self.test_session() as sess: - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_size = [in_size, out_size] - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform(kernel_size, seed=seed()), - scale=random_ops.random_uniform(kernel_size, seed=seed()), - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_size, seed=seed())) - - bias_size = [out_size] - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty) = self._testDenseSetUp( + prob_layers_lib.DenseLocalReparameterization, + batch_size, in_size, out_size) expected_kernel_posterior_affine = normal_lib.Normal( loc=math_ops.matmul(inputs, kernel_posterior.result_loc), @@ -250,21 +281,80 @@ class DenseVariationalLocalReparametrization(test.TestCase): expected_outputs = (expected_kernel_posterior_affine_tensor + bias_posterior.result_sample) - dense_vi = prob_layers_lib.DenseVariational( - units=2, - kernel_use_local_reparameterization=True, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence) + [ + expected_outputs_, actual_outputs_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, layer.bias_posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) - outputs = dense_vi(inputs) + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior.distribution, + kernel_prior.distribution, + None]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], + bias_divergence.args) + + def testDenseFlipout(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty) = self._testDenseSetUp( + prob_layers_lib.DenseFlipout, + batch_size, in_size, out_size, seed=44) + + expected_kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(kernel_posterior.result_loc), + scale=kernel_posterior.result_scale) + expected_kernel_posterior_affine_tensor = ( + expected_kernel_posterior_affine.sample(seed=42)) - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + sign_input = random_ops.random_uniform( + [batch_size, in_size], + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=layer.seed) + sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) + sign_output = random_ops.random_uniform( + [batch_size, out_size], + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=distribution_util.gen_new_seed( + layer.seed, salt="dense_flipout")) + sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) + perturbed_inputs = math_ops.matmul( + inputs * sign_input, expected_kernel_posterior_affine_tensor) + perturbed_inputs *= sign_output + + expected_outputs = math_ops.matmul(inputs, kernel_posterior.result_loc) + expected_outputs += perturbed_inputs + expected_outputs += bias_posterior.result_sample [ expected_outputs_, actual_outputs_, @@ -274,7 +364,7 @@ class DenseVariationalLocalReparametrization(test.TestCase): ] = sess.run([ expected_outputs, outputs, kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_posterior.result_sample, layer.bias_posterior_tensor, bias_divergence.result, kl_penalty[1], ]) @@ -292,13 +382,62 @@ class DenseVariationalLocalReparametrization(test.TestCase): rtol=1e-6, atol=0.) self.assertAllEqual( - [[kernel_posterior, kernel_prior, None]], + [[kernel_posterior.distribution, kernel_prior.distribution, None]], kernel_divergence.args) self.assertAllEqual( - [[bias_posterior, bias_prior, bias_posterior.result_sample]], + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], bias_divergence.args) + def testRandomDenseFlipout(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform( + [in_size, out_size], seed=seed()), + scale=random_ops.random_uniform( + [in_size, out_size], seed=seed()), + result_log_prob=random_ops.random_uniform( + [in_size, out_size], seed=seed()), + result_sample=random_ops.random_uniform( + [in_size, out_size], seed=seed())) + bias_posterior = MockDistribution( + loc=random_ops.random_uniform( + [out_size], seed=seed()), + scale=random_ops.random_uniform( + [out_size], seed=seed()), + result_log_prob=random_ops.random_uniform( + [out_size], seed=seed()), + result_sample=random_ops.random_uniform( + [out_size], seed=seed())) + layer_one = prob_layers_lib.DenseFlipout( + units=out_size, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=44) + layer_two = prob_layers_lib.DenseFlipout( + units=out_size, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=45) + + outputs_one = layer_one(inputs) + outputs_two = layer_two(inputs) + + outputs_one_, outputs_two_ = sess.run([ + outputs_one, outputs_two]) + + self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), out_size) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py index 66793383fdd5c71f136900197a91be6966e2f8c7..756c25683bd4b0c8c77e9e28485ca2a85582999c 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,9 +36,9 @@ class SGLDOptimizerTest(test.TestCase): grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) decay_rate = 0.53 - sgd_op = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) + sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) + sgd_op = sgd_optimizer.apply_gradients( + zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) @@ -54,6 +54,7 @@ class SGLDOptimizerTest(test.TestCase): decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) self.assertAllCloseAccordingToType( [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) def testBasicMultiInstance(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: @@ -102,6 +103,8 @@ class SGLDOptimizerTest(test.TestCase): sgd_optimizer2.variable_scope) self.assertNotEqual(sgd_optimizer.variable_scope.name, sgd_optimizer2.variable_scope.name) + self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) + self.assertAllCloseAccordingToType(1, sgd_optimizer2._counter.eval()) def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..83c64dbe0fd586edcb784a5c09a4c133aaa99cff --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py @@ -0,0 +1,268 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional test for GradientDescent.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.contrib.bayesflow.python.ops.optimizers import VariationalSGDOptimizer +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class VariationalSGDOptimizerTest(test.TestCase): + + def testBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.53 + sgd_op = VariationalSGDOptimizer( + 1, + 1, + preconditioner_decay_rate=decay_rate, + max_learning_rate=3.0, + burnin_max_learning_rate=3.0, + use_single_learning_rate=True).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + + def testBasicMultiInstance(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + vara = variables.Variable([1.1, 2.1], dtype=dtype) + varb = variables.Variable([3.0, 4.0], dtype=dtype) + gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) + gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.5 + batch_size = 2 + total_num_examples = 10 + optimizer = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=1.0, + burnin_max_learning_rate=3.0, + preconditioner_decay_rate=decay_rate) + sgd_op = optimizer.apply_gradients( + zip([grads0, grads1], [var0, var1])) + optimizer2 = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=1.0, + burnin_max_learning_rate=10.0, + burnin=0, + preconditioner_decay_rate=decay_rate) + sgd_op2 = optimizer2.apply_gradients( + zip([gradsa, gradsb], [vara, varb])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) + + # Run 1 step of sgd + sgd_op.run() + sgd_op2.run() + # Validate updated params + self.assertAllCloseAccordingToType([1.1 - 3. * 0.1, 2.1 - 3. * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([1.1 - 0.1, 2.1 - 0.1], vara.eval()) + + self.assertAllCloseAccordingToType([3.0 - 3. * 0.01, 4.0 - 3. * 0.01], + var1.eval()) + self.assertAllCloseAccordingToType([3.0 - 0.01, 4.0 - 0.01], + varb.eval()) + self.assertNotEqual(optimizer.variable_scope, + optimizer2.variable_scope) + self.assertNotEqual(optimizer.variable_scope.name, + optimizer2.variable_scope.name) + self.assertAllCloseAccordingToType(1, optimizer._counter.eval()) + self.assertAllCloseAccordingToType(1, optimizer2._counter.eval()) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + lrate = constant_op.constant(3.0) + decay_rate = 0.5 + batch_size = 2 + total_num_examples = 10 + sgd_op = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=lrate, + burnin=0, + preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + + def testTensorDecayLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + lrate = variables.Variable(3.0) + lrate_decay_op = lrate.assign_add(-3.) + decay_rate = 0.5 + batch_size = 2 + total_num_examples = 10 + optimizer = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=lrate, + burnin=0, + preconditioner_decay_rate=decay_rate) + sgd_op = optimizer.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + # Update learning rate to 0 + lrate_decay_op.eval() + sgd_op.run() + # Validate params haven't changed + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + lrate_decay_op.eval() + + with self.assertRaises(errors.InvalidArgumentError): + sgd_op.run() + + def testGradWrtRef(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + opt = VariationalSGDOptimizer(1, 1, max_learning_rate=1.0) + values = [1.0, 3.0] + vars_ = [variables.Variable([v], dtype=dtype) for v in values] + grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) + variables.global_variables_initializer().run() + for grad, _ in grads_and_vars: + self.assertAllCloseAccordingToType([1.0], grad.eval()) + + def testWithGlobalStep(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + global_step = variables.Variable(0, trainable=False) + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.1 + batch_size = 2 + total_num_examples = 10 + sgd_optimizer = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=3.0, + burnin=0, + preconditioner_decay_rate=decay_rate) + sgd_op = sgd_optimizer.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + + # Validate updated params and global_step + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + self.assertAllCloseAccordingToType(1, global_step.eval()) + self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) + + def testSparseBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) + var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) + grads0 = ops.IndexedSlices( + constant_op.constant([0.1], shape=[1, 1], dtype=dtype), + constant_op.constant([0]), constant_op.constant([2, 1])) + grads1 = ops.IndexedSlices( + constant_op.constant([0.01], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), constant_op.constant([2, 1])) + decay_rate = 0.1 + batch_size = 2 + total_num_examples = 10 + sgd_op = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=3.0, + burnin=0, + preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) + self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType([[1.1 - 3.0 * 0.1], [2.1]], + var0.eval()) + self.assertAllCloseAccordingToType( + [[3.0 - 3.0 * 0], [4.0 - 3.0 * 0.01]], var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py index ee3719232d8796c338247320fd8ef832a41df12b..fdc12e3b21466a2c552124d6c6a339a0c25f9f46 100644 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py @@ -43,7 +43,7 @@ def custom_gradient(fx, gx, x, axis=(), h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x)) ``` - is such that `h(x) = stop(f(x))` and `grad[h(x), x] = stop_gradient(g(x)).` + is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] = stop_gradient(g(x)).` In addition to scalar-domain/scalar-range functions, this function also supports tensor-domain/scalar-range functions. However, in the latter case it diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index 333dce929530adceb30dcb63653a5bd009c059e0..5685a942e98800a39ec718adc67bcfd43aeafd52 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -27,6 +27,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -174,9 +175,11 @@ def chain(n_iterations, step_size, n_leapfrog_steps, initial_x, potential_and_grad = _make_potential_and_grad(target_log_prob_fn) potential, grad = potential_and_grad(initial_x) - return functional_ops.scan(body, array_ops.zeros(n_iterations), - (initial_x, array_ops.zeros(non_event_shape), - -potential, -grad))[:2] + return functional_ops.scan( + body, array_ops.zeros(n_iterations, dtype=initial_x.dtype), + (initial_x, + array_ops.zeros(non_event_shape, dtype=initial_x.dtype), + -potential, -grad))[:2] def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, @@ -298,8 +301,9 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, return updated_x, acceptance_probs, w x, acceptance_probs, w = functional_ops.scan( - _body, beta_series, (initial_x, array_ops.zeros(non_event_shape), - array_ops.zeros(non_event_shape))) + _body, beta_series, + (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype), + array_ops.zeros(non_event_shape, dtype=initial_x.dtype))) return w[-1], x[-1], acceptance_probs[-1] @@ -446,9 +450,10 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), """ with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]): potential_and_grad = _make_potential_and_grad(target_log_prob_fn) + x = ops.convert_to_tensor(x, name='x') x_shape = array_ops.shape(x) - m = random_ops.random_normal(x_shape) + m = random_ops.random_normal(x_shape, dtype=x.dtype) kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims) @@ -468,26 +473,33 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims) - # TODO(mhoffman): It seems like there may be an opportunity for nans here. - # I'm delaying addressing this because we're going to refactor this part - # to use the more general Metropolis abstraction anyway. - acceptance_probs = math_ops.exp(math_ops.minimum(0., log_potential_0 - - log_potential_1 + - kinetic_0 - kinetic_1)) - accepted = math_ops.cast( - random_ops.random_uniform(array_ops.shape(acceptance_probs)) < - acceptance_probs, np.float32) - new_log_prob = (-log_potential_0 * (1. - accepted) - - log_potential_1 * accepted) + energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0 + # Treat NaN as infinite energy (and therefore guaranteed rejection). + energy_change = array_ops.where( + math_ops.is_nan(energy_change), + array_ops.fill(array_ops.shape(energy_change), + energy_change.dtype.as_numpy_dtype(np.inf)), + energy_change) + acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.)) + accepted = ( + random_ops.random_uniform( + array_ops.shape(acceptance_probs), dtype=x.dtype) + < acceptance_probs) + new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0) # TODO(b/65738010): This should work, but it doesn't for now. # reduced_shape = math_ops.reduced_shape(x_shape, event_dims) reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims, keep_dims=True)) accepted = array_ops.reshape(accepted, reduced_shape) - new_x = x * (1. - accepted) + new_x * accepted - new_grad = -grad_0 * (1. - accepted) - grad_1 * accepted - + accepted = math_ops.logical_or( + accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool)) + new_x = array_ops.where(accepted, new_x, x) + new_grad = -array_ops.where(accepted, grad_1, grad_0) + + # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect + # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This + # should be fixed. return new_x, acceptance_probs, new_log_prob, new_grad @@ -525,6 +537,7 @@ def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum, Has shape matching `initial_position`. Example: Simple quadratic potential. + ```python def potential_and_grad(position): return tf.reduce_sum(0.5 * tf.square(position)), position @@ -600,6 +613,7 @@ def leapfrog_step(step_size, position, momentum, potential_and_grad, grad, Has shape matching `position`. Example: Simple quadratic potential. + ```python def potential_and_grad(position): # Simple quadratic potential diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py index dcead38af826a12e776160bdb251ba021e6b953c..a742b7c1aa593d6c08bf9d8d597c99c9fc4e7aed 100644 --- a/tensorflow/contrib/bayesflow/python/ops/layers.py +++ b/tensorflow/contrib/bayesflow/python/ops/layers.py @@ -23,13 +23,43 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import * +from tensorflow.contrib.bayesflow.python.ops.layers_conv_variational import * +from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational import * +from tensorflow.contrib.bayesflow.python.ops.layers_util import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'DenseVariational', - 'dense_variational', + 'Convolution1DReparameterization', + 'Convolution2DReparameterization', + 'Convolution3DReparameterization', + 'Convolution1DFlipout', + 'Convolution2DFlipout', + 'Convolution3DFlipout', + 'Conv1DReparameterization', + 'Conv2DReparameterization', + 'Conv3DReparameterization', + 'Conv1DFlipout', + 'Conv2DFlipout', + 'Conv3DFlipout', + 'convolution1d_reparameterization', + 'convolution2d_reparameterization', + 'convolution3d_reparameterization', + 'convolution1d_flipout', + 'convolution2d_flipout', + 'convolution3d_flipout', + 'conv1d_reparameterization', + 'conv2d_reparameterization', + 'conv3d_reparameterization', + 'conv1d_flipout', + 'conv2d_flipout', + 'conv3d_flipout', + 'DenseReparameterization', + 'DenseLocalReparameterization', + 'DenseFlipout', + 'dense_reparameterization', + 'dense_local_reparameterization', + 'dense_flipout', 'default_loc_scale_fn', 'default_mean_field_normal_fn', ] diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py new file mode 100644 index 0000000000000000000000000000000000000000..7723cfb442712626ff415f1412e3362f2392ce9f --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py @@ -0,0 +1,2943 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Convolutional variational layer classes and their functional aliases. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as layers_lib +from tensorflow.python.layers import utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops.distributions import kullback_leibler as kl_lib +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util + + +class _ConvVariational(layers_lib.Layer): + """Abstract nD convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + + Properties: + rank: Python integer, dimensionality of convolution. + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(_ConvVariational, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self.rank = rank + self.filters = filters + self.kernel_size = utils.normalize_tuple(kernel_size, rank, "kernel_size") + self.strides = utils.normalize_tuple(strides, rank, "strides") + self.padding = utils.normalize_padding(padding) + self.data_format = utils.normalize_data_format(data_format) + self.dilation_rate = utils.normalize_tuple( + dilation_rate, rank, "dilation_rate") + self.activation = activation + self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2) + self.kernel_posterior_fn = kernel_posterior_fn + self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn + self.kernel_prior_fn = kernel_prior_fn + self.kernel_divergence_fn = kernel_divergence_fn + self.bias_posterior_fn = bias_posterior_fn + self.bias_posterior_tensor_fn = bias_posterior_tensor_fn + self.bias_prior_fn = bias_prior_fn + self.bias_divergence_fn = bias_divergence_fn + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if self.data_format == "channels_first": + channel_axis = 1 + else: + channel_axis = -1 + if input_shape[channel_axis].value is None: + raise ValueError("The channel dimension of the inputs " + "should be defined. Found `None`.") + input_dim = input_shape[channel_axis].value + kernel_shape = self.kernel_size + (input_dim, self.filters) + dtype = dtypes.as_dtype(self.dtype) + + # Must have a posterior kernel. + self.kernel_posterior = self.kernel_posterior_fn( + dtype, kernel_shape, "kernel_posterior", + self.trainable, self.add_variable) + + if self.kernel_prior_fn is None: + self.kernel_prior = None + else: + self.kernel_prior = self.kernel_prior_fn( + dtype, kernel_shape, "kernel_prior", + self.trainable, self.add_variable) + self._built_kernel_divergence = False + + if self.bias_posterior_fn is None: + self.bias_posterior = None + else: + self.bias_posterior = self.bias_posterior_fn( + dtype, (self.filters,), "bias_posterior", + self.trainable, self.add_variable) + + if self.bias_prior_fn is None: + self.bias_prior = None + else: + self.bias_prior = self.bias_prior_fn( + dtype, (self.filters,), "bias_prior", + self.trainable, self.add_variable) + self._built_bias_divergence = False + + self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2, + axes={channel_axis: input_dim}) + self._convolution_op = nn_ops.Convolution( + input_shape, + filter_shape=tensor_shape.TensorShape(kernel_shape), + dilation_rate=self.dilation_rate, + strides=self.strides, + padding=self.padding.upper(), + data_format=utils.convert_data_format(self.data_format, + self.rank + 2)) + + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + + outputs = self._apply_variational_kernel(inputs) + outputs = self._apply_variational_bias(outputs) + if self.activation is not None: + outputs = self.activation(outputs) + if not self._built_kernel_divergence: + kernel_posterior = self.kernel_posterior + kernel_prior = self.kernel_prior + if isinstance(self.kernel_posterior, independent_lib.Independent): + kernel_posterior = kernel_posterior.distribution + if isinstance(self.kernel_prior, independent_lib.Independent): + kernel_prior = kernel_prior.distribution + self._apply_divergence(self.kernel_divergence_fn, + kernel_posterior, + kernel_prior, + self.kernel_posterior_tensor, + name="divergence_kernel") + self._built_kernel_divergence = True + if not self._built_bias_divergence: + bias_posterior = self.bias_posterior + bias_prior = self.bias_prior + if isinstance(self.bias_posterior, independent_lib.Independent): + bias_posterior = bias_posterior.distribution + if isinstance(self.bias_prior, independent_lib.Independent): + bias_prior = bias_prior.distribution + self._apply_divergence(self.bias_divergence_fn, + bias_posterior, + bias_prior, + self.bias_posterior_tensor, + name="divergence_bias") + self._built_bias_divergence = True + return outputs + + def _apply_variational_bias(self, inputs): + if self.bias_posterior is None: + self.bias_posterior_tensor = None + return inputs + self.bias_posterior_tensor = self.bias_posterior_tensor_fn( + self.bias_posterior) + outputs = inputs + if self.data_format == "channels_first": + if self.rank == 1: + # nn.bias_add does not accept a 1D input tensor. + bias = array_ops.reshape(self.bias_posterior_tensor, + (1, self.filters, 1)) + outputs += bias + if self.rank == 2: + outputs = nn.bias_add(outputs, + self.bias_posterior_tensor, + data_format="NCHW") + if self.rank == 3: + # As of Mar 2017, direct addition is significantly slower than + # bias_add when computing gradients. To use bias_add, we collapse Z + # and Y into a single dimension to obtain a 4D input tensor. + outputs_shape = outputs.shape.as_list() + outputs_4d = array_ops.reshape(outputs, + [outputs_shape[0], outputs_shape[1], + outputs_shape[2] * outputs_shape[3], + outputs_shape[4]]) + outputs_4d = nn.bias_add(outputs_4d, + self.bias_posterior_tensor, + data_format="NCHW") + outputs = array_ops.reshape(outputs_4d, outputs_shape) + else: + outputs = nn.bias_add(outputs, + self.bias_posterior_tensor, + data_format="NHWC") + return outputs + + def _apply_divergence(self, divergence_fn, posterior, prior, + posterior_tensor, name): + if (divergence_fn is None or + posterior is None or + prior is None): + divergence = None + return + divergence = standard_ops.identity( + divergence_fn( + posterior, prior, posterior_tensor), + name=name) + self.add_loss(divergence) + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).as_list() + if self.data_format == "channels_last": + space = input_shape[1:-1] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0]] + new_space + + [self.filters]) + else: + space = input_shape[2:] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0], self.filters] + + new_space) + + +class _ConvReparameterization(_ConvVariational): + """Abstract nD convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + + Properties: + rank: Python integer, dimensionality of convolution. + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(_ConvReparameterization, self).__init__( + rank=rank, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + def _apply_variational_kernel(self, inputs): + self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( + self.kernel_posterior) + self.kernel_posterior_affine = None + self.kernel_posterior_affine_tensor = None + outputs = self._convolution_op(inputs, self.kernel_posterior_tensor) + return outputs + + +class Conv1DReparameterization(_ConvReparameterization): + """1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.Conv1DReparameterization(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.DenseReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(Conv1DReparameterization, self).__init__( + rank=1, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + +def conv1d_reparameterization( + inputs, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Functional interface for 1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.conv1d_reparameterization(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.dense_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + layer = Conv1DReparameterization( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv2DReparameterization(_ConvReparameterization): + """2D convolution layer (e.g. spatial convolution over images). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.Conv2DReparameterization(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling2D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 8 * 8 * 64]) + logits = tfp.layers.DenseReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(Conv2DReparameterization, self).__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + +def conv2d_reparameterization( + inputs, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Functional interface for the 2D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.conv2d_reparameterization(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling2d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 8 * 8 * 64]) + logits = tfp.layers.dense_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + layer = Conv2DReparameterization( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv3DReparameterization(_ConvReparameterization): + """3D convolution layer (e.g. spatial convolution over volumes). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.Conv3DReparameterization(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling2D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) + logits = tfp.layers.DenseReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(Conv3DReparameterization, self).__init__( + rank=3, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + +def conv3d_reparameterization( + inputs, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Functional interface for the 3D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.conv3d_reparameterization(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling2d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) + logits = tfp.layers.dense_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + layer = Conv3DReparameterization( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class _ConvFlipout(_ConvVariational): + """Abstract nD convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + + Properties: + rank: Python integer, dimensionality of convolution. + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(_ConvFlipout, self).__init__( + rank=rank, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + self.seed = seed + + def _apply_variational_kernel(self, inputs): + if (not isinstance(self.kernel_posterior, independent_lib.Independent) or + not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): + raise TypeError( + "`{}` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Independent(tf.distributions.Normal)` " + "(saw: \"{}\").".format( + type(self).__name__, self.kernel_posterior.name)) + self.kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), + scale=self.kernel_posterior.distribution.scale) + self.kernel_posterior_affine_tensor = ( + self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) + self.kernel_posterior_tensor = None + + outputs = self._convolution_op( + inputs, self.kernel_posterior.distribution.loc) + + input_shape = array_ops.shape(inputs) + output_shape = array_ops.shape(outputs) + batch_shape = array_ops.expand_dims(input_shape[0], 0) + channels = input_shape[-1] + + sign_input = layers_util.random_sign( + array_ops.concat([batch_shape, + array_ops.expand_dims(channels, 0)], 0), + dtype=inputs.dtype, + seed=self.seed) + sign_output = layers_util.random_sign( + array_ops.concat([batch_shape, + array_ops.expand_dims(self.filters, 0)], 0), + dtype=inputs.dtype, + seed=distribution_util.gen_new_seed( + self.seed, salt="conv_flipout")) + for _ in range(self.rank): + sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) + sign_output = array_ops.expand_dims(sign_output, 1) + + sign_input = array_ops.tile( # tile for element-wise op broadcasting + sign_input, + [1] + [input_shape[i + 1] for i in range(self.rank)] + [1]) + sign_output = array_ops.tile( + sign_output, + [1] + [output_shape[i + 1] for i in range(self.rank)] + [1]) + + perturbed_inputs = self._convolution_op( + inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output + + outputs += perturbed_inputs + return outputs + + +class Conv1DFlipout(_ConvFlipout): + """1D convolution layer (e.g. temporal convolution) with Flipout. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.Conv1DFlipout(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.DenseFlipout(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(Conv1DFlipout, self).__init__( + rank=1, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, **kwargs) + + +def conv1d_flipout( + inputs, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + reuse=None): + """Functional interface for 1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.conv1d_flipout(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.dense_flipout(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + layer = Conv1DFlipout( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv2DFlipout(_ConvFlipout): + """2D convolution layer (e.g. spatial convolution over images) with Flipout. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.Conv2DFlipout(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling2D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 8 * 8 * 64]) + logits = tfp.layers.DenseFlipout(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(Conv2DFlipout, self).__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, **kwargs) + + +def conv2d_flipout( + inputs, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + reuse=None): + """Functional interface for the 2D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.conv2d_flipout(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling2d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 8 * 8 * 64]) + logits = tfp.layers.dense_flipout(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + layer = Conv2DFlipout( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv3DFlipout(_ConvFlipout): + """3D convolution layer (e.g. spatial convolution over volumes) with Flipout. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.Conv3DFlipout(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling2D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) + logits = tfp.layers.DenseFlipout(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(Conv3DFlipout, self).__init__( + rank=3, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, **kwargs) + + +def conv3d_flipout( + inputs, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + reuse=None): + """Functional interface for the 3D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.conv3d_flipout(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling2d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 256 * 8 * 8 * 64]) + logits = tfp.layers.dense_flipout(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + layer = Conv3DFlipout( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +# Aliases + +Convolution1DReparameterization = Conv1DReparameterization +Convolution2DReparameterization = Conv2DReparameterization +Convolution3DReparameterization = Conv3DReparameterization +convolution1d_reparameterization = conv1d_reparameterization +convolution2d_reparameterization = conv2d_reparameterization +convolution3d_reparameterization = conv3d_reparameterization +Convolution1DFlipout = Conv1DFlipout +Convolution2DFlipout = Conv2DFlipout +Convolution3DFlipout = Conv3DFlipout +convolution1d_flipout = conv1d_flipout +convolution2d_flipout = conv2d_flipout +convolution3d_flipout = conv3d_flipout diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py new file mode 100644 index 0000000000000000000000000000000000000000..591a8e553de0c194786c7ee8693665f762711b2d --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py @@ -0,0 +1,1176 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dense Bayesian layer using KL-divergence based variational inference. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as layers_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops.distributions import kullback_leibler as kl_lib +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util + + +class _DenseVariational(layers_lib.Layer): + """Abstract densely-connected class (private, used as implementation base). + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(_DenseVariational, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self.units = units + self.activation = activation + self.input_spec = layers_lib.InputSpec(min_ndim=2) + self.kernel_posterior_fn = kernel_posterior_fn + self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn + self.kernel_prior_fn = kernel_prior_fn + self.kernel_divergence_fn = kernel_divergence_fn + self.bias_posterior_fn = bias_posterior_fn + self.bias_posterior_tensor_fn = bias_posterior_tensor_fn + self.bias_prior_fn = bias_prior_fn + self.bias_divergence_fn = bias_divergence_fn + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + in_size = input_shape.with_rank_at_least(2)[-1].value + if in_size is None: + raise ValueError("The last dimension of the inputs to `Dense` " + "should be defined. Found `None`.") + self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) + dtype = dtypes.as_dtype(self.dtype) + + # Must have a posterior kernel. + self.kernel_posterior = self.kernel_posterior_fn( + dtype, [in_size, self.units], "kernel_posterior", + self.trainable, self.add_variable) + + if self.kernel_prior_fn is None: + self.kernel_prior = None + else: + self.kernel_prior = self.kernel_prior_fn( + dtype, [in_size, self.units], "kernel_prior", + self.trainable, self.add_variable) + self._built_kernel_divergence = False + + if self.bias_posterior_fn is None: + self.bias_posterior = None + else: + self.bias_posterior = self.bias_posterior_fn( + dtype, [self.units], "bias_posterior", + self.trainable, self.add_variable) + + if self.bias_prior_fn is None: + self.bias_prior = None + else: + self.bias_prior = self.bias_prior_fn( + dtype, [self.units], "bias_prior", + self.trainable, self.add_variable) + self._built_bias_divergence = False + + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + + outputs = self._apply_variational_kernel(inputs) + outputs = self._apply_variational_bias(outputs) + if self.activation is not None: + outputs = self.activation(outputs) # pylint: disable=not-callable + if not self._built_kernel_divergence: + kernel_posterior = self.kernel_posterior + kernel_prior = self.kernel_prior + if isinstance(self.kernel_posterior, independent_lib.Independent): + kernel_posterior = kernel_posterior.distribution + if isinstance(self.kernel_prior, independent_lib.Independent): + kernel_prior = kernel_prior.distribution + self._apply_divergence(self.kernel_divergence_fn, + kernel_posterior, + kernel_prior, + self.kernel_posterior_tensor, + name="divergence_kernel") + self._built_kernel_divergence = True + if not self._built_bias_divergence: + bias_posterior = self.bias_posterior + bias_prior = self.bias_prior + if isinstance(self.bias_posterior, independent_lib.Independent): + bias_posterior = bias_posterior.distribution + if isinstance(self.bias_prior, independent_lib.Independent): + bias_prior = bias_prior.distribution + self._apply_divergence(self.bias_divergence_fn, + bias_posterior, + bias_prior, + self.bias_posterior_tensor, + name="divergence_bias") + self._built_bias_divergence = True + return outputs + + def _apply_variational_bias(self, inputs): + if self.bias_posterior is None: + self.bias_posterior_tensor = None + return inputs + self.bias_posterior_tensor = self.bias_posterior_tensor_fn( + self.bias_posterior) + return nn.bias_add(inputs, self.bias_posterior_tensor) + + def _apply_divergence(self, divergence_fn, posterior, prior, + posterior_tensor, name): + if (divergence_fn is None or + posterior is None or + prior is None): + divergence = None + return + divergence = standard_ops.identity( + divergence_fn( + posterior, prior, posterior_tensor), + name=name) + self.add_loss(divergence) + + def _matmul(self, inputs, kernel): + if inputs.shape.ndims <= 2: + return standard_ops.matmul(inputs, kernel) + # To handle broadcasting, we must use `tensordot`. + return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) + if input_shape[-1].value is None: + raise ValueError( + "The innermost dimension of input_shape must be defined, " + "but saw: {}".format(input_shape)) + return input_shape[:-1].concatenate(self.units) + + +class DenseReparameterization(_DenseVariational): + """Densely-connected layer class with reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the reparameterization estimator [1], which performs a Monte Carlo + approximation of the distribution integrating over the `kernel` and + `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.DenseReparameterization( + 512, activation=tf.nn.relu)(features) + logits = tfp.layers.DenseReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseReparameterization, self).__init__( + units=units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + **kwargs) + + def _apply_variational_kernel(self, inputs): + self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( + self.kernel_posterior) + self.kernel_posterior_affine = None + self.kernel_posterior_affine_tensor = None + return self._matmul(inputs, self.kernel_posterior_tensor) + + +def dense_reparameterization( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Densely-connected layer with reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the reparameterization estimator [1], which performs a Monte Carlo + approximation of the distribution integrating over the `kernel` and + `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing a the affine transformed input under a random + draw from the surrogate posterior distribution. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.dense_reparameterization( + features, 512, activation=tf.nn.relu) + logits = tfp.layers.dense_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + layer = DenseReparameterization( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class DenseLocalReparameterization(_DenseVariational): + """Densely-connected layer class with local reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the local reparameterization estimator [1], which performs a + Monte Carlo approximation of the distribution on the hidden units + induced by the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.DenseLocalReparameterization( + 512, activation=tf.nn.relu)(features) + logits = tfp.layers.DenseLocalReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses local reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Variational Dropout and the Local Reparameterization Trick." + Diederik P. Kingma, Tim Salimans, Max Welling. + Neural Information Processing Systems, 2015. + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseLocalReparameterization, self).__init__( + units=units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + **kwargs) + + def _apply_variational_kernel(self, inputs): + if (not isinstance(self.kernel_posterior, independent_lib.Independent) or + not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): + raise TypeError( + "`DenseLocalReparameterization` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Independent(tf.distributions.Normal)` " + "(saw: \"{}\").".format(self.kernel_posterior.name)) + self.kernel_posterior_affine = normal_lib.Normal( + loc=self._matmul(inputs, self.kernel_posterior.distribution.loc), + scale=standard_ops.sqrt(self._matmul( + standard_ops.square(inputs), + standard_ops.square(self.kernel_posterior.distribution.scale)))) + self.kernel_posterior_affine_tensor = ( + self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) + self.kernel_posterior_tensor = None + return self.kernel_posterior_affine_tensor + + +def dense_local_reparameterization( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Densely-connected layer with local reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the local reparameterization estimator [1], which performs a + Monte Carlo approximation of the distribution on the hidden units + induced by the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing a the affine transformed input under a random + draw from the surrogate posterior distribution. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.dense_local_reparameterization( + features, 512, activation=tf.nn.relu) + logits = tfp.layers.dense_local_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses local reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Variational Dropout and the Local Reparameterization Trick." + Diederik P. Kingma, Tim Salimans, Max Welling. + Neural Information Processing Systems, 2015. + """ + layer = DenseLocalReparameterization( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class DenseFlipout(_DenseVariational): + """Densely-connected layer class with Flipout estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the Flipout estimator [1], which performs a Monte Carlo + approximation of the distribution integrating over the `kernel` and + `bias`. Flipout uses roughly twice as many floating point operations + as the reparameterization estimator but has the advantage of + significantly lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.DenseFlipout( + 512, activation=tf.nn.relu)(features) + logits = tfp.layers.DenseFlipout(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(DenseFlipout, self).__init__( + units=units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + **kwargs) + self.seed = seed + + def _apply_variational_kernel(self, inputs): + if (not isinstance(self.kernel_posterior, independent_lib.Independent) or + not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): + raise TypeError( + "`DenseFlipout` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Independent(tf.distributions.Normal)` " + "(saw: \"{}\").".format(self.kernel_posterior.name)) + self.kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), + scale=self.kernel_posterior.distribution.scale) + self.kernel_posterior_affine_tensor = ( + self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) + self.kernel_posterior_tensor = None + + input_shape = array_ops.shape(inputs) + batch_shape = input_shape[:-1] + + sign_input = layers_util.random_sign( + input_shape, + dtype=inputs.dtype, + seed=self.seed) + sign_output = layers_util.random_sign( + array_ops.concat([batch_shape, + array_ops.expand_dims(self.units, 0)], 0), + dtype=inputs.dtype, + seed=distribution_util.gen_new_seed( + self.seed, salt="dense_flipout")) + perturbed_inputs = self._matmul( + inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output + + outputs = self._matmul(inputs, self.kernel_posterior.distribution.loc) + outputs += perturbed_inputs + return outputs + + +def dense_flipout( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + reuse=None): + """Densely-connected layer with Flipout estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the Flipout estimator [1], which performs a Monte Carlo + approximation of the distribution integrating over the `kernel` and + `bias`. Flipout uses roughly twice as many floating point operations + as the reparameterization estimator but has the advantage of + significantly lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing a the affine transformed input under a random + draw from the surrogate posterior distribution. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.dense_flipout( + features, 512, activation=tf.nn.relu) + logits = tfp.layers.dense_flipout(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + layer = DenseFlipout( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py deleted file mode 100644 index b05ce0ffc1dd55ffb029b339a846a9aa5c877620..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py +++ /dev/null @@ -1,797 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Dense Bayesian layer using KL-divergence based variational inference. - -@@DenseVariational -@@dense_variational - -@@default_loc_scale_fn -@@default_mean_field_normal_fn -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as layers_lib -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import standard_ops -from tensorflow.python.ops.distributions import kullback_leibler as kl_lib -from tensorflow.python.ops.distributions import normal as normal_lib - - -__all__ = [ - "DenseVariational", - "dense_variational", - "default_loc_scale_fn", - "default_mean_field_normal_fn", -] - - -def default_loc_scale_fn( - is_singular=False, - loc_initializer=init_ops.random_normal_initializer(stddev=0.1), - untransformed_scale_initializer=init_ops.random_normal_initializer( - mean=-3., stddev=0.1), - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. - - This function produces a closure which produces `loc`, `scale` using - `tf.get_variable`. The closure accepts the following arguments: - - dtype: Type of parameter's event. - shape: Python `list`-like representing the parameter's event shape. - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` indicating if `scale is None`. Default: `False`. - loc_initializer: Initializer function for the `loc` parameters. - The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. Default value: `tf.random_normal_initializer(mean=-3., - stddev=0.1)`. This implies the softplus transformed result has mean - approximately `0.05` and std. deviation approximately `0.005`. - loc_regularizer: Regularizer function for the `loc` parameters. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. The default (`None`) is to use the `tf.get_variable` default. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. The default - (`None`) is to use the `tf.get_variable` default. - - Returns: - default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` - parameters from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates `loc`, `scale` parameters.""" - loc = add_variable_fn( - name=name + "_loc", - shape=shape, - initializer=loc_initializer, - regularizer=loc_regularizer, - constraint=loc_constraint, - dtype=dtype, - trainable=trainable) - if is_singular: - return loc, None - untransformed_scale = add_variable_fn( - name=name + "_untransformed_scale", - shape=shape, - initializer=untransformed_scale_initializer, - regularizer=untransformed_scale_regularizer, - constraint=untransformed_scale_constraint, - dtype=dtype, - trainable=trainable) - scale = (np.finfo(dtype.as_numpy_dtype).eps + - nn_ops.softplus(untransformed_scale)) - return loc, scale - return _fn - - -def default_mean_field_normal_fn( - is_singular=False, - loc_initializer=None, - untransformed_scale_initializer=None, - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Creates a function to build Normal distributions with trainable params. - - This function produces a closure which produces `tf.distributions.Normal` - parameterized by a loc` and `scale` each created using `tf.get_variable`. The - produced closure accepts the following arguments: - - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` if `True`, forces the special case limit of - `scale->0`, i.e., a `Deterministic` distribution. - loc_initializer: Initializer function for the `loc` parameters. - If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - loc_regularizer: Regularizer function for the `loc` parameters. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. - - Returns: - make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` - using from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - loc_scale_fn_ = default_loc_scale_fn( - is_singular, - loc_initializer, - untransformed_scale_initializer, - loc_regularizer, - untransformed_scale_regularizer, - loc_constraint, - untransformed_scale_constraint) - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates a batch of `Deterministic` or `Normal` distributions.""" - loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) - if scale is None: - return deterministic_lib.Deterministic(loc=loc) - return normal_lib.Normal(loc=loc, scale=scale) - return _fn - - -class DenseVariational(layers_lib.Layer): - """Densely-connected variational class. - - This layer implements the Bayesian variational inference analogue to: - `outputs = activation(matmul(inputs, kernel) + bias)` - by assuming the `kernel` and/or the `bias` are random variables. - - The layer implements a stochastic dense calculation by making a Monte Carlo - approximation of a [variational Bayesian method based on KL divergence]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., - - ```none - -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw - = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw - <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's - = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] - ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } - + KL[q(W|x), p(W)] - ``` - - where `W` denotes the (independent) `kernel` and `bias` random variables, `w` - is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, - and `~=` denotes an approximation which becomes exact as `m->inf`. The above - bound is sometimes referred to as the negative Evidence Lower BOund or - negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this - layer is appropriate to use when the final loss is a negative log-likelihood. - - The Monte-Carlo sum portion is used for the feed-forward calculation of the - DNN. The KL divergence portion can be added to the final loss via: - `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - random variables (which together comprise `W`). - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_use_local_reparameterization: Python `bool` indicating whether - `kernel` calculation should employ the Local Reparameterization Trick. - When `True`, `kernel_posterior_fn` must create an instance of - `tf.distributions.Normal`. - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_use_local_reparameterization: Python `bool` indicating whether - `kernel` calculation should employ the Local Reparameterization Trick. - kernel: `VariationalKernelParamater` instance containing all `kernel` - related properties and `callable`s. - bias: `VariationalParameter` instance containing all `kernel` - related properties and `callable`s. - """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_use_local_reparameterization=True, - kernel_posterior_fn=default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(DenseVariational, self).__init__( - trainable=trainable, - name=name, - activity_regularizer=activity_regularizer, - **kwargs) - self._units = units - self._activation = activation - self._input_spec = layers_lib.InputSpec(min_ndim=2) - self._kernel_use_local_reparameterization = ( - kernel_use_local_reparameterization) - self._kernel = VariationalKernelParameter( - kernel_posterior_fn, - kernel_posterior_tensor_fn, - kernel_prior_fn, - kernel_divergence_fn) - self._bias = VariationalParameter( - bias_posterior_fn, - bias_posterior_tensor_fn, - bias_prior_fn, - bias_divergence_fn) - - @property - def units(self): - return self._units - - @property - def activation(self): - return self._activation - - @property - def input_spec(self): - return self._input_spec - - @input_spec.setter - def input_spec(self, value): - self._input_spec = value - - @property - def kernel_use_local_reparameterization(self): - return self._kernel_use_local_reparameterization - - @property - def kernel(self): - return self._kernel - - @property - def bias(self): - return self._bias - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - in_size = input_shape.with_rank_at_least(2)[-1].value - if in_size is None: - raise ValueError("The last dimension of the inputs to `Dense` " - "should be defined. Found `None`.") - self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) - dtype = dtypes.as_dtype(self.dtype) - - # Must have a posterior kernel. - self.kernel.posterior = self.kernel.posterior_fn( - dtype, [in_size, self.units], "kernel_posterior", - self.trainable, self.add_variable) - - if self.kernel.prior_fn is None: - self.kernel_prior = None - else: - self.kernel.prior = self.kernel.prior_fn( - dtype, [in_size, self.units], "kernel_prior", - self.trainable, self.add_variable) - self._built_kernel_divergence = False - - if self.bias.posterior_fn is None: - self.bias.posterior = None - else: - self.bias.posterior = self.bias.posterior_fn( - dtype, [self.units], "bias_posterior", - self.trainable, self.add_variable) - - if self.bias.prior_fn is None: - self.bias.prior = None - else: - self.bias.prior = self.bias.prior_fn( - dtype, [self.units], "bias_prior", - self.trainable, self.add_variable) - self._built_bias_divergence = False - - self.built = True - - def call(self, inputs): - inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - - outputs = self._apply_variational_kernel(inputs) - outputs = self._apply_variational_bias(outputs) - if self.activation is not None: - outputs = self.activation(outputs) # pylint: disable=not-callable - if not self._built_kernel_divergence: - self._apply_divergence(self.kernel, name="divergence_kernel") - self._built_kernel_divergence = True - if not self._built_bias_divergence: - self._apply_divergence(self.bias, name="divergence_bias") - self._built_bias_divergence = True - return outputs - - def _apply_variational_kernel(self, inputs): - if not self.kernel_use_local_reparameterization: - self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn( - self.kernel.posterior) - self.kernel.posterior_affine = None - self.kernel.posterior_affine_tensor = None - return self._matmul(inputs, self.kernel.posterior_tensor) - if not isinstance(self.kernel.posterior, normal_lib.Normal): - raise TypeError("`kernel_use_local_reparameterization=True` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Normal` (saw: \"{}\").".format( - type(self.kernel.posterior).__name__)) - self.kernel.posterior_affine = normal_lib.Normal( - loc=self._matmul(inputs, self.kernel.posterior.loc), - scale=standard_ops.sqrt(self._matmul( - standard_ops.square(inputs), - standard_ops.square(self.kernel.posterior.scale)))) - self.kernel.posterior_affine_tensor = ( - self.kernel.posterior_tensor_fn(self.kernel.posterior_affine)) - self.kernel.posterior_tensor = None - return self.kernel.posterior_affine_tensor - - def _apply_variational_bias(self, inputs): - if self.bias.posterior is None: - self.bias.posterior_tensor = None - return inputs - self.bias.posterior_tensor = self.bias.posterior_tensor_fn( - self.bias.posterior) - return nn.bias_add(inputs, self.bias.posterior_tensor) - - def _apply_divergence(self, param, name): - if (param.divergence_fn is None or - param.posterior is None or - param.prior is None): - param.divergence = None - return - param.divergence = standard_ops.identity( - param.divergence_fn( - param.posterior, param.prior, param.posterior_tensor), - name=name) - self.add_loss(param.divergence) - - def _matmul(self, inputs, kernel): - if inputs.shape.ndims <= 2: - return standard_ops.matmul(inputs, kernel) - # To handle broadcasting, we must use `tensordot`. - return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) - if input_shape[-1].value is None: - raise ValueError( - "The innermost dimension of input_shape must be defined, " - "but saw: {}".format(input_shape)) - return input_shape[:-1].concatenate(self.units) - - -def dense_variational( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_use_local_reparameterization=True, - kernel_posterior_fn=default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Densely-connected variational layer. - - This layer implements the Bayesian variational inference analogue to: - `outputs = activation(matmul(inputs, kernel) + bias)` - by assuming the `kernel` and/or the `bias` are random variables. - - The layer implements a stochastic dense calculation by making a Monte Carlo - approximation of a [variational Bayesian method based on KL divergence]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., - - ```none - -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw - = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw - <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's - = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] - ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } - + KL[q(W|x), p(W)] - ``` - - where `W` denotes the (independent) `kernel` and `bias` random variables, `w` - is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, - and `~=` denotes an approximation which becomes exact as `m->inf`. The above - bound is sometimes referred to as the negative Evidence Lower BOund or - negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this - layer is appropriate to use when the final loss is a negative log-likelihood. - - The Monte-Carlo sum portion is used for the feed-forward calculation of the - DNN. The KL divergence portion can be added to the final loss via: - `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - random variables (which together comprise `W`). - - Args: - inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_use_local_reparameterization: Python `bool` indicating whether - `kernel` calculation should employ the Local Reparameterization Trick. - When `True`, `kernel_posterior_fn` must create an instance of - `tf.distributions.Normal`. - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - """ - layer = DenseVariational( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_use_local_reparameterization=( - kernel_use_local_reparameterization), - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class NotSet(object): - """Helper to track whether a `VariationalParameter` value has been set.""" - pass - - -class VariationalParameter(object): - """Struct-like container of variational parameter properties. - - A `VariationalParameter` is intitialized with Python `callable`s which set the - value of correspondingly named members. Corresponding values have "set once" - semantics, i.e., once set to any value they are immutable. - """ - - def __init__( - self, - posterior_fn, - posterior_tensor_fn, - prior_fn, - divergence_fn): - """Creates the `VariationalParameter` struct-like object. - - Args: - posterior_fn: Python `callable` which creates a - `tf.distribution.Distribution` like object representing the posterior - distribution. See `VariationalParameter.posterior_fn` for `callable`'s - required parameters. - posterior_tensor_fn: Python `callable` which computes a `Tensor` - which represents the `posterior`. - prior_fn: Python `callable` which creates a - `tf.distribution.Distribution` like object representing the prior - distribution. See `VariationalParameter.prior_fn` for `callable`'s - required parameters. - divergence_fn: Python `callable` which computes the KL divergence from - `posterior` to `prior`. See `VariationalParameter.divergence_fn` for - required `callable`'s parameters. - """ - self._posterior_fn = posterior_fn - self._posterior = NotSet() - self._posterior_tensor_fn = posterior_tensor_fn - self._posterior_tensor = NotSet() - self._prior_fn = prior_fn - self._prior = NotSet() - self._divergence_fn = divergence_fn - self._divergence = NotSet() - self._init_helper() - - @property - def posterior_fn(self): - """`callable` which creates `tf.distributions.Distribution`-like posterior. - - The `callable` must accept the following parameters: - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Returns: - posterior_fn: The Python `callable` specified in `__init__`. - """ - return self._posterior_fn - - @property - def posterior(self): - """`tf.distributions.Distribution`-like instance representing posterior.""" - return self._posterior - - @posterior.setter - def posterior(self, value): - """One-time setter of the `posterior` distribution.""" - if not isinstance(self._posterior, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior = value - - @property - def posterior_tensor_fn(self): - """Creates `Tensor` representing the `posterior` distribution. - - The `callable` must accept the following parameters: - posterior: `tf.distributions.Distribution`-like instance. - - Returns: - posterior_tensor_fn: The Python `callable` specified in - `__init__`. - """ - return self._posterior_tensor_fn - - @property - def posterior_tensor(self): - """`Tensor` representing the `posterior` distribution.""" - return self._posterior_tensor - - @posterior_tensor.setter - def posterior_tensor(self, value): - """One-time setter of the `posterior_tensor`.""" - if not isinstance(self._posterior_tensor, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_tensor = value - - @property - def prior_fn(self): - """`callable` which creates `tf.distributions.Distribution`-like prior. - - The `callable` must accept the following parameters: - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Returns: - prior_fn: The Python `callable` specified in `__init__`. - """ - return self._prior_fn - - @property - def prior(self): - """`tf.distributions.Distribution`-like instance representing posterior.""" - return self._prior - - @prior.setter - def prior(self, value): - """One-time setter of the `prior` distribution.""" - if not isinstance(self._prior, NotSet): - raise ValueError("Cannot override already set attribute.") - self._prior = value - - @property - def divergence_fn(self): - """`callable` which computes KL-divergence `Tensor` from posterior to prior. - - The `callable` must accept the following parameters: - posterior: `tf.distributions.Distribution`-like instance. - prior: `tf.distributions.Distribution`-like instance. - posterior_tensor: `Tensor` representing value of posterior. - - Returns: - divergence_fn: The Python `callable` specified in `__init__`. - """ - return self._divergence_fn - - @property - def divergence(self): - """`Tensor` representing KL-divergence from posterior to prior.""" - return self._divergence - - @divergence.setter - def divergence(self, value): - """One-time setter of the `divergence`.""" - if not isinstance(self._divergence, NotSet): - raise ValueError("Cannot override already set attribute.") - self._divergence = value - - def _init_helper(self): - pass - - -class VariationalKernelParameter(VariationalParameter): - """Struct-like container of variational kernel properties. - - A `VariationalKernelParameter` is intitialized with Python `callable`s which - set the value of correspondingly named members. Corresponding values have "set - once" semantics, i.e., once set to any value they are immutable. - """ - - @property - def posterior_affine(self): - """`tf.distributions.Distribution` affine transformed posterior.""" - return self._posterior_affine - - @posterior_affine.setter - def posterior_affine(self, value): - """One-time setter of `posterior_affine`.""" - if not isinstance(self._posterior_affine, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_affine = value - - @property - def posterior_affine_tensor(self): - """`Tensor` representing the `posterior_affine` distribution.""" - return self._posterior_affine_tensor - - @posterior_affine_tensor.setter - def posterior_affine_tensor(self, value): - """One-time setter of the `posterior_affine_tensor`.""" - if not isinstance(self._posterior_affine_tensor, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_affine_tensor = value - - def _init_helper(self): - self._posterior_affine = NotSet() - self._posterior_affine_tensor = NotSet() diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_util.py b/tensorflow/contrib/bayesflow/python/ops/layers_util.py new file mode 100644 index 0000000000000000000000000000000000000000..8c1fb203f7328e8260e49b4326d813fbe133613e --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_util.py @@ -0,0 +1,191 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for probabilistic layers. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import normal as normal_lib + + +def default_loc_scale_fn( + is_singular=False, + loc_initializer=init_ops.random_normal_initializer(stddev=0.1), + untransformed_scale_initializer=init_ops.random_normal_initializer( + mean=-3., stddev=0.1), + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. + + This function produces a closure which produces `loc`, `scale` using + `tf.get_variable`. The closure accepts the following arguments: + + dtype: Type of parameter's event. + shape: Python `list`-like representing the parameter's event shape. + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` indicating if `scale is None`. Default: `False`. + loc_initializer: Initializer function for the `loc` parameters. + The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. Default value: `tf.random_normal_initializer(mean=-3., + stddev=0.1)`. This implies the softplus transformed result has mean + approximately `0.05` and std. deviation approximately `0.005`. + loc_regularizer: Regularizer function for the `loc` parameters. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. The default (`None`) is to use the `tf.get_variable` default. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. The default + (`None`) is to use the `tf.get_variable` default. + + Returns: + default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` + parameters from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates `loc`, `scale` parameters.""" + loc = add_variable_fn( + name=name + "_loc", + shape=shape, + initializer=loc_initializer, + regularizer=loc_regularizer, + constraint=loc_constraint, + dtype=dtype, + trainable=trainable) + if is_singular: + return loc, None + untransformed_scale = add_variable_fn( + name=name + "_untransformed_scale", + shape=shape, + initializer=untransformed_scale_initializer, + regularizer=untransformed_scale_regularizer, + constraint=untransformed_scale_constraint, + dtype=dtype, + trainable=trainable) + scale = (np.finfo(dtype.as_numpy_dtype).eps + + nn_ops.softplus(untransformed_scale)) + return loc, scale + return _fn + + +def default_mean_field_normal_fn( + is_singular=False, + loc_initializer=None, + untransformed_scale_initializer=None, + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Creates a function to build Normal distributions with trainable params. + + This function produces a closure which produces `tf.distributions.Normal` + parameterized by a loc` and `scale` each created using `tf.get_variable`. The + produced closure accepts the following arguments: + + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` if `True`, forces the special case limit of + `scale->0`, i.e., a `Deterministic` distribution. + loc_initializer: Initializer function for the `loc` parameters. + If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + loc_regularizer: Regularizer function for the `loc` parameters. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. + + Returns: + make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` + using from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + loc_scale_fn_ = default_loc_scale_fn( + is_singular, + loc_initializer, + untransformed_scale_initializer, + loc_regularizer, + untransformed_scale_regularizer, + loc_constraint, + untransformed_scale_constraint) + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates multivariate `Deterministic` or `Normal` distribution.""" + loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) + if scale is None: + dist = deterministic_lib.Deterministic(loc=loc) + else: + dist = normal_lib.Normal(loc=loc, scale=scale) + reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] + return independent_lib.Independent( + dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims) + return _fn + + +def random_sign(shape, dtype=dtypes.float32, seed=None): + """Draw values from {-1, 1} uniformly, i.e., Rademacher distribution.""" + random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2, + dtype=dtypes.int32, + seed=seed) + return math_ops.cast(2 * random_bernoulli - 1, dtype) diff --git a/tensorflow/contrib/bayesflow/python/ops/optimizers.py b/tensorflow/contrib/bayesflow/python/ops/optimizers.py index ee32e6b5c3d9efaeaf73436638c5eea55f2cfc70..fb70628d1083836281e9327e83e109493276c64f 100644 --- a/tensorflow/contrib/bayesflow/python/ops/optimizers.py +++ b/tensorflow/contrib/bayesflow/python/ops/optimizers.py @@ -24,11 +24,13 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import * +from tensorflow.contrib.bayesflow.python.ops.variational_sgd_optimizer import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'SGLDOptimizer', + 'VariationalSGDOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py index 5d36ea7a2b51aa45cdc253992a2a58634c068987..7786656398e3c87704227be95b3cd23a38785249 100644 --- a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py +++ b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py @@ -189,6 +189,10 @@ class SGLDOptimizer(optimizer.Optimizer): new_grad, use_locking=self._use_locking).op + def _finish(self, update_ops, name_scope): + update_ops.append([self._counter.assign_add(1)]) + return control_flow_ops.group(*update_ops, name=name_scope) + @property def variable_scope(self): """Variable scope of all calls to `tf.get_variable`.""" diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5f0cfe9713a011b32c5aba8d429847d81f33e2 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py @@ -0,0 +1,279 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An optimizer module for constant stochastic gradient descent.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops + + +class VariationalSGDOptimizer(optimizer.Optimizer): + """An optimizer module for constant stochastic gradient descent. + + This implements an optimizer module for the constant stochastic gradient + descent algorithm [1]. The optimization variable is regarded as an + approximate sample from the posterior . + + Note: If a prior is included in the loss, it should be scaled by + `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches + in the data. I.e., it should be divided by the `num_pseudo_batches` term + described below. + + [1]: "Stochastic Gradient Descent as Approximate Bayesian Inference + Stephan Mandt, Matthew D. Hoffman, David M. Blei. + ArXiv:1704.04289, 2017. https://arxiv.org/abs/1704.04289 + + Args: + batch_size: Scalar `int`-like `Tensor`. The number of examples in a + minibatch in the data set. Note: Assumes the loss is taken as the mean + over a minibatch. Otherwise if the sum was taken set this to 1. + total_num_examples: Scalar `int`-like `Tensor`. The total number of examples + in the data set. + max_learning_rate: Scalar `float`-like `Tensor`. A maximum allowable + effective coordinate-wise learning rate. The algorithm scales down any + effective learning rate (i.e. after preconditioning) that is larger than + this. (Default: `1`) + preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential + decay rate of the rescaling of the preconditioner (RMSprop). (This is + "alpha" in [1]). Should be smaller than but nearly `1` to approximate + sampling from the posterior. (Default: `0.95`) + burnin: Scalar `int`-like `Tensor`. The number of iterations to collect + gradient statistics to update the preconditioner before starting to draw + noisy samples. (Default: `25`) + burnin_max_learning_rate: Scalar `float`-like `Tensor`. Maximum learning + rate to use during the burnin period. + (Default: `1e-8`) + use_single_learning_rate: Boolean Indicates whether one single learning + rate is used or coordinate_wise learning rates are used. + (Default: `False`) + name: Python `str` describing ops managed by this function. + (Default: `"VariationalSGDOptimizer"`) + variable_scope: Variable scope used for calls to `tf.get_variable`. + If `None`, a new variable scope is created using name + `ops.get_default_graph().unique_name(name or default_name)`. + + Raises: + InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in + `(0,1]`. + """ + + def __init__(self, + batch_size, + total_num_examples, + max_learning_rate=1.0, + preconditioner_decay_rate=0.95, + burnin=25, + burnin_max_learning_rate=1e-6, + use_single_learning_rate=False, + name=None, + variable_scope=None): + default_name = 'VariationalSGDOptimizer' + with ops.name_scope(name, default_name, [ + max_learning_rate, preconditioner_decay_rate, batch_size, burnin, + burnin_max_learning_rate + ]): + if variable_scope is None: + var_scope_name = ops.get_default_graph().unique_name( + name or default_name) + with varscope_ops.variable_scope(var_scope_name) as scope: + self._variable_scope = scope + else: + self._variable_scope = variable_scope + + self._preconditioner_decay_rate = ops.convert_to_tensor( + preconditioner_decay_rate, name='preconditioner_decay_rate') + self._batch_size = ops.convert_to_tensor(batch_size, name='batch_size') + self._total_num_examples = ops.convert_to_tensor( + total_num_examples, name='total_num_examples') + self._burnin = ops.convert_to_tensor(burnin, name='burnin') + self._burnin_max_learning_rate = ops.convert_to_tensor( + burnin_max_learning_rate, name='burnin_max_learning_rate') + self._max_learning_rate = ops.convert_to_tensor( + max_learning_rate, name='max_learning_rate') + self._use_single_learning_rate = use_single_learning_rate + + with varscope_ops.variable_scope(self._variable_scope): + self._counter = varscope_ops.get_variable( + 'counter', initializer=0, trainable=False) + + self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._preconditioner_decay_rate, + message='`preconditioner_decay_rate` must be non-negative'), + check_ops.assert_less_equal( + self._preconditioner_decay_rate, + 1., + message='`preconditioner_decay_rate` must be at most 1.'), + ], self._preconditioner_decay_rate) + + self._batch_size = control_flow_ops.with_dependencies([ + check_ops.assert_greater( + self._batch_size, + 0, + message='`batch_size` must be greater than zero') + ], self._batch_size) + + self._total_num_examples = control_flow_ops.with_dependencies([ + check_ops.assert_greater( + self._total_num_examples, + 0, + message='`total_num_examples` must be greater than zero') + ], self._total_num_examples) + + self._burnin = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._burnin, message='`burnin` must be non-negative'), + check_ops.assert_integer( + self._burnin, message='`burnin` must be an integer') + ], self._burnin) + + self._burnin_max_learning_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._burnin_max_learning_rate, + message='`burnin_max_learning_rate` must be non-negative') + ], self._burnin_max_learning_rate) + + self._max_learning_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._max_learning_rate, + message='`max_learning_rate` must be non-negative') + ], self._max_learning_rate) + + super(VariationalSGDOptimizer, self).__init__( + use_locking=False, name=name or default_name) + + def _create_slots(self, var_list): + for v in var_list: + init_moment = init_ops.zeros_initializer(dtype=v.dtype) + self._get_or_make_slot_with_initializer( + v, init_moment, v.get_shape(), v.dtype, 'first_moment', self._name) + self._get_or_make_slot_with_initializer( + v, init_moment, v.get_shape(), v.dtype, 'second_moment', self._name) + + def _prepare(self): + self._decay_tensor = ops.convert_to_tensor( + self._preconditioner_decay_rate, name='preconditioner_decay_rate') + self._batch_size_tensor = ops.convert_to_tensor( + self._batch_size, name='batch_size_tensor') + + super(VariationalSGDOptimizer, self)._prepare() + + def _get_coordinatewise_learning_rate(self, grad, var): + # Compute the learning rate using a moving average for the diagonal of BB^T + avg_first = self.get_slot(var, 'first_moment') + avg_second = self.get_slot(var, 'second_moment') + decay_tensor = math_ops.cast(self._decay_tensor, var.dtype) + batch_size = math_ops.cast(self._batch_size_tensor, var.dtype) + + # Create an estimator for the moving average of gradient mean and variance + # via Welford's algorithm + if isinstance(grad, ops.Tensor): + delta = grad - avg_first + first_moment_update = avg_first.assign_add( + array_ops.where(self._counter < 1, math_ops.cast(1, var.dtype), + 1. - decay_tensor) * delta) + + with ops.control_dependencies([first_moment_update]): + second_moment_update = avg_second.assign_add( + math_ops.cast(self._counter < 1, var.dtype) * + -(1. - decay_tensor) * ( + avg_second - decay_tensor * math_ops.square(delta))) + diag_preconditioner = control_flow_ops.with_dependencies( + [second_moment_update], + clip_ops.clip_by_value(avg_second, 1e-12, 1e12)) + elif isinstance(grad, ops.IndexedSlices): + delta = grad.values - array_ops.gather_nd(avg_first, grad.indices) + first_moment_update = state_ops.scatter_add( + avg_first, + grad.indices, + array_ops.where(self._counter < 1, + math_ops.cast(1., var.dtype), + 1. - decay_tensor) * delta) + + with ops.control_dependencies([first_moment_update]): + avg_second = state_ops.scatter_add( + avg_second, + grad.indices, + math_ops.cast(self._counter < 1, var.dtype) * + -(1. - decay_tensor) * ( + array_ops.gather_nd(avg_second, grad.indices) - decay_tensor * + math_ops.square(delta))) + avg_second = array_ops.gather_nd(avg_second, grad.indices) + # TODO(b/70783772) + diag_preconditioner = clip_ops.clip_by_value(avg_second, 1e-12, 1e12) + else: + raise errors.InvalidArgumentError( + None, None, 'grad must of type Tensor or IndexedSlice') + + diag_preconditioner *= batch_size + + if self._use_single_learning_rate: + diag_preconditioner = math_ops.reduce_mean(diag_preconditioner) + + # From Theorem 2 Corollary 1 of Mandt et al. 2017 + return 2. * batch_size / ( + math_ops.cast(self._total_num_examples, var.dtype.base_dtype) * + diag_preconditioner) + + def _apply_dense(self, grad, var): + + max_learning_rate = array_ops.where(self._counter < self._burnin, + self._burnin_max_learning_rate, + self._max_learning_rate) + + learn_rates = clip_ops.clip_by_value( + self._get_coordinatewise_learning_rate(grad, var), 0.0, + math_ops.cast(max_learning_rate, var.dtype.base_dtype)) + + newgrad = grad * learn_rates + return training_ops.apply_gradient_descent( + var, + math_ops.cast(1.0, var.dtype), + newgrad, + use_locking=self._use_locking).op + + def _apply_sparse(self, grad, var): + + max_learning_rate = array_ops.where(self._counter < self._burnin, + self._burnin_max_learning_rate, + self._max_learning_rate) + + learn_rate = clip_ops.clip_by_value( + self._get_coordinatewise_learning_rate(grad, var), 0.0, + math_ops.cast(max_learning_rate, var.dtype)) + delta = grad.values * learn_rate + + return state_ops.scatter_sub(var, grad.indices, delta, + use_locking=self._use_locking) + + def _finish(self, update_ops, name_scope): + update_ops.append([self._counter.assign_add(1)]) + return control_flow_ops.group(*update_ops, name=name_scope) + + @property + def variable_scope(self): + """Variable scope of all calls to `tf.get_variable`.""" + return self._variable_scope diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 7072f56420ac9e576b20b62c0aa67498857403a7..392ac7fa1ce600a64ee3b941b70b01447645e4aa 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -601,6 +601,7 @@ py_library( ":init_py", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees/estimator_batch:custom_export_strategy", + "//tensorflow/contrib/boosted_trees/estimator_batch:dnn_tree_combined_estimator", "//tensorflow/contrib/boosted_trees/estimator_batch:init_py", "//tensorflow/contrib/boosted_trees/estimator_batch:trainer_hooks", "//tensorflow/contrib/boosted_trees/lib:categorical_split_handler", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 7792c7127c0285dc2eb5b213da054674f6a81d64..289f5bb3140974d8c37f4938ceef27275b099f9a 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -50,6 +50,7 @@ py_library( deps = [ "//tensorflow/contrib/learn", "//tensorflow/core:protos_all_py", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", "//tensorflow/python:training", @@ -129,3 +130,38 @@ py_library( "//tensorflow/python:math_ops", ], ) + +py_library( + name = "dnn_tree_combined_estimator", + srcs = ["dnn_tree_combined_estimator.py"], + srcs_version = "PY2AND3", + deps = [ + ":trainer_hooks", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/boosted_trees:model_ops_py", + "//tensorflow/contrib/learn", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + ], +) + +py_test( + name = "dnn_tree_combined_estimator_test", + size = "small", + srcs = ["dnn_tree_combined_estimator_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_gpu", + "no_pip_gpu", + "notsan", + ], + deps = [ + ":dnn_tree_combined_estimator", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index ef8dee91b6cc05c4c3dd5eb3c81de4fb65b473e3..31f5c444817b9b82723c86bea3504d4934e57eb8 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -33,6 +33,8 @@ from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader as saved_model_loader from tensorflow.python.saved_model import tag_constants +_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" + def make_custom_export_strategy(name, convert_fn, @@ -147,13 +149,15 @@ def convert_to_universal_format(dtec, sorted_feature_names, inequality_test.threshold.float_value = split.threshold elif node_type == "sparse_float_binary_split_default_left": split = gtflow_node.sparse_float_binary_split_default_left.split - node.default_direction = ( - generic_tree_model_pb2.BinaryNode.LEFT) - # TODO(nponomareva): adjust this id assignement when we allow multi- - # column sparse tensors. + node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT) feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) + model_and_features.features.pop(sorted_feature_names[feature_id]) + (model_and_features.features[inequality_test.feature_id.id.value] + .SetInParent()) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -165,7 +169,12 @@ def convert_to_universal_format(dtec, sorted_feature_names, # column sparse tensors. feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) + model_and_features.features.pop(sorted_feature_names[feature_id]) + (model_and_features.features[inequality_test.feature_id.id.value] + .SetInParent()) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -201,10 +210,14 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split_column = feature_names[split.feature_column] elif node_type == "sparse_float_binary_split_default_left": split = tree_node.sparse_float_binary_split_default_left.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "sparse_float_binary_split_default_right": split = tree_node.sparse_float_binary_split_default_right.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "categorical_id_binary_split": split = tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py index 4ed18b2d34c5af47826ab1c058f5d13797593bd4..67ec0e16bf815e9dbea6567cc87c3980a825a004 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the conversion code from GTFlow format to Chauffeur.""" +"""Tests for the conversion code and for feature importances export. + +Tests that cover conversion from TFBT format to a tensorflow.contrib. +decision_tree generic_tree_model format and feature importances export. +""" from __future__ import absolute_import from __future__ import division @@ -95,10 +99,31 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } + nodes { + sparse_float_binary_split_default_right { + split { + feature_column: 1 + dimension_id:3 + threshold: -0.4 + left_id: 7 + right_id: 8 + } + } + node_metadata { + gain: 3600 + } + } + nodes { + leaf { + vector { + value: 0.36 + } + } + } nodes { leaf { vector { - value: 0.3 + value: 18 } } } @@ -108,17 +133,25 @@ class ConvertModelTest(test_util.TensorFlowTestCase): """ dtec = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge(dtec_str, dtec) - feature_columns = ["feature_b", "feature_a", "feature_d"] + feature_columns = [ + "feature_b", + "feature_a", + "feature_a_m", + "feature_d", + ] return dtec, feature_columns def testConvertModel(self): dtec, feature_columns = self._make_trees() + # Assume 2 sparse float columns, one with 1 dimension, the second one with + # 5 dimensions. # The feature columns in the order they were added. out = custom_export_strategy.convert_to_universal_format( - dtec, feature_columns, 1, 1, - 1) + dtec, feature_columns, 1, 2, 1) + # Features a and a_m are sparse float features, a_m is multidimensional. expected_tree = """ - features { key: "feature_a" } + features { key: "feature_a_0" } + features { key: "feature_a_m_3" } features { key: "feature_b" } features { key: "feature_d" } model { @@ -169,7 +202,6 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } - nodes { node_id { value: 1 @@ -196,7 +228,7 @@ class ConvertModelTest(test_util.TensorFlowTestCase): inequality_left_child_test { feature_id { id { - value: "feature_a" + value: "feature_a_0" } } threshold { @@ -259,14 +291,51 @@ class ConvertModelTest(test_util.TensorFlowTestCase): node_id { value: 6 } + binary_node { + left_child_id { + value: 7 + } + right_child_id { + value: 8 + } + default_direction: RIGHT + inequality_left_child_test { + feature_id { + id { + value: "feature_a_m_3" + } + } + threshold { + float_value: -0.4 + } + } + } + } + nodes { + node_id { + value: 7 + } leaf { vector { value { - float_value: 0.03 + float_value: 0.036 } } } } + nodes { + node_id { + value: 8 + } + leaf { + vector { + value { + float_value: 1.8 + } + } + } + } + } } submodel_id { @@ -280,12 +349,15 @@ class ConvertModelTest(test_util.TensorFlowTestCase): def testFeatureImportance(self): dtec, feature_columns = self._make_trees() feature_importances = custom_export_strategy._get_feature_importances( - dtec, feature_columns, 1, 1, 1) - self.assertItemsEqual(["feature_b", "feature_a", "feature_d"], - feature_importances.keys()) + dtec, feature_columns, 1, 2, 1) + self.assertItemsEqual( + ["feature_b", "feature_a_0", "feature_a_m_3", "feature_d"], + feature_importances.keys()) self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4) - self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4) + self.assertAlmostEqual(50.0, feature_importances["feature_a_0"], places=4) self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4) + self.assertAlmostEqual( + 360.0, feature_importances["feature_a_m_3"], places=4) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..cec3892b57655dc967b4e7926f7f5a6a30084487 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -0,0 +1,515 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow estimators for combined DNN + GBDT training model. + +The combined model trains a DNN first, then trains boosted trees to boost the +logits of the DNN. The input layer of the DNN (including the embeddings learned +over sparse features) can optionally be provided to the boosted trees as +an additional input feature. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks +from tensorflow.contrib.boosted_trees.python.ops import model_ops +from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch +from tensorflow.contrib.layers.python.layers import optimizers +from tensorflow.contrib.learn.python.learn.estimators import estimator +from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import training_util + + +_DNN_LEARNING_RATE = 0.001 + + +def _get_optimizer(optimizer): + if callable(optimizer): + return optimizer() + else: + return optimizer + + +def _add_hidden_layer_summary(value, tag): + summary.scalar("%s_fraction_of_zero_values" % tag, nn.zero_fraction(value)) + summary.histogram("%s_activation" % tag, value) + + +def _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, + dnn_feature_columns, tree_learner_config, num_trees, + tree_examples_per_layer, + config=None, dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """DNN and GBDT combined model_fn. + + Args: + features: `dict` of `Tensor` objects. + labels: Labels used to train on. + mode: Mode we are in. (TRAIN/EVAL/INFER) + head: A `Head` instance. + dnn_hidden_units: List of hidden units per layer. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + config: `RunConfig` of the estimator. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate of 0.001. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + + Returns: + A `ModelFnOps` object. + Raises: + ValueError: if inputs are not valid. + """ + if not isinstance(features, dict): + raise ValueError("features should be a dictionary of `Tensor`s. " + "Given type: {}".format(type(features))) + + if not dnn_feature_columns: + raise ValueError("dnn_feature_columns must be specified") + + # Build DNN Logits. + dnn_parent_scope = "dnn" + dnn_partitioner = dnn_input_layer_partitioner or ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=config.num_ps_replicas, + min_slice_size=64 << 20)) + + with variable_scope.variable_scope( + dnn_parent_scope, + values=tuple(six.itervalues(features)), + partitioner=dnn_partitioner): + + with variable_scope.variable_scope( + "input_from_feature_columns", + values=tuple(six.itervalues(features)), + partitioner=dnn_partitioner) as input_layer_scope: + input_layer = layers.input_from_feature_columns( + columns_to_tensors=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope], + scope=input_layer_scope) + previous_layer = input_layer + for layer_id, num_hidden_units in enumerate(dnn_hidden_units): + with variable_scope.variable_scope( + "hiddenlayer_%d" % layer_id, + values=(previous_layer,)) as hidden_layer_scope: + net = layers.fully_connected( + previous_layer, + num_hidden_units, + activation_fn=dnn_activation_fn, + variables_collections=[dnn_parent_scope], + scope=hidden_layer_scope) + if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN: + net = layers.dropout(net, keep_prob=(1.0 - dnn_dropout)) + _add_hidden_layer_summary(net, hidden_layer_scope.name) + previous_layer = net + with variable_scope.variable_scope( + "logits", + values=(previous_layer,)) as logits_scope: + dnn_logits = layers.fully_connected( + previous_layer, + head.logits_dimension, + activation_fn=None, + variables_collections=[dnn_parent_scope], + scope=logits_scope) + _add_hidden_layer_summary(dnn_logits, logits_scope.name) + + def _dnn_train_op_fn(loss): + """Returns the op to optimize the loss.""" + return optimizers.optimize_loss( + loss=loss, + global_step=training_util.get_global_step(), + learning_rate=_DNN_LEARNING_RATE, + optimizer=_get_optimizer(dnn_optimizer), + name=dnn_parent_scope, + variables=ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, + scope=dnn_parent_scope), + # Empty summaries to prevent optimizers from logging training_loss. + summaries=[]) + + # Build Tree Logits. + global_step = training_util.get_global_step() + with ops.device(global_step.device): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config="", # Initialize an empty ensemble. + name="ensemble_model") + + tree_features = features.copy() + if dnn_input_layer_to_tree: + tree_features["dnn_input_layer"] = input_layer + tree_feature_columns.append(layers.real_valued_column("dnn_input_layer")) + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=config.is_chief, + num_ps_replicas=config.num_ps_replicas, + ensemble_handle=ensemble_handle, + center_bias=tree_center_bias, + examples_per_layer=tree_examples_per_layer, + learner_config=tree_learner_config, + feature_columns=tree_feature_columns, + logits_dimension=head.logits_dimension, + features=tree_features) + + with ops.name_scope("gbdt"): + predictions_dict = gbdt_model.predict(mode) + tree_logits = predictions_dict["predictions"] + + def _tree_train_op_fn(loss): + """Returns the op to optimize the loss.""" + update_op = gbdt_model.train(loss, predictions_dict, labels) + with ops.control_dependencies( + [update_op]), (ops.colocate_with(global_step)): + update_op = state_ops.assign_add(global_step, 1).op + return update_op + + tree_train_logits = dnn_logits + tree_logits + + def _no_train_op_fn(loss): + """Returns a no-op.""" + del loss + return control_flow_ops.no_op() + + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op + + if tree_center_bias: + num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() + + model_fn_ops.training_hooks.extend([ + trainer_hooks.SwitchTrainOp( + dnn_train_op, dnn_steps_to_train, tree_train_op), + trainer_hooks.StopAfterNTrees( + num_trees, attempted_trees, finalized_trees)]) + + return model_fn_ops + + +class DNNBoostedTreeCombinedClassifier(estimator.Estimator): + """A classifier that uses a combined DNN/GBDT model.""" + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + n_classes=2, + weight_column_name=None, + model_dir=None, + config=None, + label_name=None, + label_keys=None, + feature_engineering_fn=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """Initializes a DNNBoostedTreeCombinedClassifier instance. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + n_classes: The number of label classes. + weight_column_name: The name of weight column. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + label_name: String, name of the key in label dict. Can be null if label + is a tensor (single headed models). + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + """ + head = head_lib.multi_class_head( + n_classes=n_classes, + label_name=label_name, + label_keys=label_keys, + weight_column_name=weight_column_name, + enable_centered_bias=False) + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, + tree_learner_config, num_trees, tree_examples_per_layer, config, + dnn_optimizer, dnn_activation_fn, dnn_dropout, + dnn_input_layer_partitioner, dnn_input_layer_to_tree, + dnn_steps_to_train, + tree_feature_columns, tree_center_bias) + + super(DNNBoostedTreeCombinedClassifier, self).__init__( + model_fn=_model_fn, model_dir=model_dir, + config=config, feature_engineering_fn=feature_engineering_fn) + + +class DNNBoostedTreeCombinedRegressor(estimator.Estimator): + """A regressor that uses a combined DNN/GBDT model.""" + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + weight_column_name=None, + model_dir=None, + config=None, + label_name=None, + label_dimension=1, + feature_engineering_fn=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """Initializes a DNNBoostedTreeCombinedRegressor instance. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + weight_column_name: The name of weight column. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + label_name: String, name of the key in label dict. Can be null if label + is a tensor (single headed models). + label_dimension: Number of regression labels per example. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + """ + head = head_lib.regression_head( + label_name=label_name, + label_dimension=label_dimension, + weight_column_name=weight_column_name, + enable_centered_bias=False) + + # num_classes needed for GradientBoostedDecisionTreeModel + if label_dimension == 1: + tree_learner_config.num_classes = 2 + else: + tree_learner_config.num_classes = label_dimension + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, + tree_learner_config, num_trees, tree_examples_per_layer, config, + dnn_optimizer, dnn_activation_fn, dnn_dropout, + dnn_input_layer_partitioner, dnn_input_layer_to_tree, + dnn_steps_to_train, tree_feature_columns, tree_center_bias) + + super(DNNBoostedTreeCombinedRegressor, self).__init__( + model_fn=_model_fn, model_dir=model_dir, + config=config, feature_engineering_fn=feature_engineering_fn) + + +class DNNBoostedTreeCombinedEstimator(estimator.Estimator): + """An estimator that uses a combined DNN/GBDT model. + + Useful for training with user specified `Head`. + """ + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + head, + model_dir=None, + config=None, + feature_engineering_fn=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """Initializes a DNNBoostedTreeCombinedEstimator instance. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + head: `Head` instance. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + """ + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, + tree_learner_config, num_trees, tree_examples_per_layer, config, + dnn_optimizer, dnn_activation_fn, dnn_dropout, + dnn_input_layer_partitioner, dnn_input_layer_to_tree, + dnn_steps_to_train, + tree_feature_columns, tree_center_bias) + + super(DNNBoostedTreeCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, + config=config, feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..83d58c561008e8a5a69eb503d1605bb9e940f281 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -0,0 +1,105 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for combined DNN + GBDT estimators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile + +from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils +from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +def _train_input_fn(): + features = { + "x": constant_op.constant([[2.], [1.], [1.]]) + } + label = constant_op.constant([[1], [0], [0]], dtype=dtypes.int32) + return features, label + + +def _eval_input_fn(): + features = { + "x": constant_op.constant([[1.], [2.], [2.]]) + } + label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32) + return features, label + + +class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): + + def testClassifierContract(self): + estimator_test_utils.assert_estimator_contract( + self, estimator.DNNBoostedTreeCombinedClassifier) + + def testRegressorContract(self): + estimator_test_utils.assert_estimator_contract( + self, estimator.DNNBoostedTreeCombinedRegressor) + + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract( + self, estimator.DNNBoostedTreeCombinedEstimator) + + def testNoDNNFeatureColumns(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + + with self.assertRaisesRegexp( + ValueError, + "dnn_feature_columns must be specified"): + classifier = estimator.DNNBoostedTreeCombinedClassifier( + dnn_hidden_units=[1], + dnn_feature_columns=[], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + n_classes=2) + classifier.fit(input_fn=_train_input_fn, steps=5) + + def testFitAndEvaluateDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.DNNBoostedTreeCombinedClassifier( + dnn_hidden_units=[1], + dnn_feature_columns=[feature_column.real_valued_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + n_classes=2, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py index 79193fffc3d3fa97e20a12181bf20e6ad86dcb58..2e4151cac40f770e2bece70d752122eb7f34dd40 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py @@ -24,6 +24,7 @@ from tensorflow.contrib.learn.python.learn import session_run_hook from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArgs from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.training.summary_io import SummaryWriterCache @@ -175,3 +176,40 @@ class StopAfterNTrees(session_run_hook.SessionRunHook): logging.info("Requesting stop since we have reached %d trees.", num_finalized_trees) run_context.request_stop() + + +class SwitchTrainOp(session_run_hook.SessionRunHook): + """Hook that switches the train op after specified number of steps. + + Hook that replaces the train op depending on the number of steps of training + that have taken place. The first_train_op is used till train_steps steps + are reached. Thereafter the second_train_op is used. + """ + + def __init__(self, first_train_op, train_steps, second_train_op): + """Initializes a `SwitchTrainOp`.""" + self._first_train_op = first_train_op + self._second_train_op = second_train_op + self._train_steps = train_steps + + def _get_train_op_for_global_step(self, current_step): + """Gets train_op for current global step.""" + if current_step < self._train_steps: + return self._first_train_op + return self._second_train_op + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + self._current_train_op = control_flow_ops.no_op() + if self._global_step_tensor is None: + raise RuntimeError( + "Global step should be created to use SwitchTrainOp.") + + def before_run(self, run_context): # pylint: disable=unused-argument + return session_run_hook.SessionRunArgs( + {"global_step": self._global_step_tensor, + "train_op": self._current_train_op}) + + def after_run(self, run_context, run_values): + self._current_train_op = self._get_train_op_for_global_step( + run_values.results["global_step"]) diff --git a/tensorflow/contrib/boosted_trees/examples/boston_combined.py b/tensorflow/contrib/boosted_trees/examples/boston_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..e04b56afbfd266dc13a5b0d78d171ea273415ee3 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/boston_combined.py @@ -0,0 +1,165 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Regression on Boston housing data using DNNBoostedTreeCombinedRegressor. + + Example Usage: + + python tensorflow/contrib/boosted_trees/examples/boston_combined.py \ + --batch_size=404 --output_dir="/tmp/boston" \ + --dnn_hidden_units="8,4" --dnn_steps_to_train=1000 \ + --tree_depth=4 --tree_learning_rate=0.1 \ + --num_trees=100 --tree_l2=0.001 --num_eval_steps=1 \ + --vmodule=training_ops=1 + + When training is done, mean squared error on eval data is reported. + Point tensorboard to the directory for the run to see how the training + progresses: + + tensorboard --logdir=/tmp/boston + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +import tensorflow as tf + +from tensorflow.contrib.boosted_trees.estimator_batch.dnn_tree_combined_estimator import DNNBoostedTreeCombinedRegressor +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.learn.python.learn import learn_runner +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils + +_BOSTON_NUM_FEATURES = 13 + + +def _get_estimator(output_dir, feature_cols): + """Configures DNNBoostedTreeCombinedRegressor based on flags.""" + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = ( + FLAGS.tree_learning_rate) + learner_config.regularization.l1 = 0.0 + learner_config.regularization.l2 = FLAGS.tree_l2 + learner_config.constraints.max_tree_depth = FLAGS.tree_depth + + run_config = tf.contrib.learn.RunConfig(save_summary_steps=1) + + # Create a DNNBoostedTreeCombinedRegressor estimator. + estimator = DNNBoostedTreeCombinedRegressor( + dnn_hidden_units=[int(x) for x in FLAGS.dnn_hidden_units.split(",")], + dnn_feature_columns=feature_cols, + tree_learner_config=learner_config, + num_trees=FLAGS.num_trees, + # This should be the number of examples. For large datasets it can be + # larger than the batch_size. + tree_examples_per_layer=FLAGS.batch_size, + model_dir=output_dir, + config=run_config, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=FLAGS.dnn_steps_to_train) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for DNNBoostedTreeCombinedRegressor.""" + (x_train, y_train), (x_test, + y_test) = tf.keras.datasets.boston_housing.load_data() + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_train}, + y=y_train, + batch_size=FLAGS.batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) + + feature_columns = [ + feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES) + ] + feature_spec = tf.contrib.layers.create_feature_spec_for_parsing( + feature_columns) + serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec) + export_strategies = [ + saved_model_export_utils.make_export_strategy(serving_input_fn)] + return tf.contrib.learn.Experiment( + estimator=_get_estimator(output_dir, feature_columns), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None, + export_strategies=export_strategies) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for configuring DNNBoostedTreeCombinedRegressor. + parser.add_argument( + "--dnn_hidden_units", + type=str, + default="8,4", + help="Hidden layers for DNN.") + parser.add_argument( + "--dnn_steps_to_train", + type=int, + default=1000, + help="Number of steps to train DNN.") + parser.add_argument( + "--tree_depth", type=int, default=4, help="Maximum depth of trees.") + parser.add_argument( + "--tree_l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--tree_learning_rate", + type=float, + default=0.1, + help=("Learning rate (shrinkage weight) with which each " + "new tree is added.")) + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index c77d90e243c304ec8e9a10a0b63401f9bd825c3e..7f8dea1d3c2a04b725843f6e2932a0cdfbc7733c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -361,10 +361,27 @@ class GrowTreeEnsembleOp : public OpKernel { // Increment attempt stats. ensemble_resource->IncrementAttempts(); + // In case we want to do feature selection and we have reached the limit, + // build a list of handlers used so far to avoid adding new features. + std::vector allowed_handlers; + if (learner_config_.constraints().max_number_of_unique_feature_columns() > + 0) { + allowed_handlers = ensemble_resource->GetUsedHandlers(); + // TODO(soroush): We can disable handlers that are not going to be used to + // avoid unnecessary computations. + if (allowed_handlers.size() < + learner_config_.constraints() + .max_number_of_unique_feature_columns()) { + // We have not reached the limit yet. Empty the list of allow features + // which means we can keep adding new features. + allowed_handlers.clear(); + } + } + // Find best splits for each active partition. std::map best_splits; - FindBestSplitsPerPartition(context, partition_ids_list, gains_list, - splits_list, &best_splits); + FindBestSplitsPerPartition(context, allowed_handlers, partition_ids_list, + gains_list, splits_list, &best_splits); // No-op if no new splits can be considered. if (best_splits.empty()) { @@ -381,7 +398,8 @@ class GrowTreeEnsembleOp : public OpKernel { // Split tree nodes. for (auto& split_entry : best_splits) { - SplitTreeNode(split_entry.first, &split_entry.second, tree_config); + SplitTreeNode(split_entry.first, &split_entry.second, tree_config, + ensemble_resource); } // Post-prune finalized tree if needed. @@ -403,12 +421,20 @@ class GrowTreeEnsembleOp : public OpKernel { // Helper method which effectively does a reduce over all split candidates // and finds the best split for each partition. void FindBestSplitsPerPartition( - OpKernelContext* const context, const OpInputList& partition_ids_list, - const OpInputList& gains_list, const OpInputList& splits_list, + OpKernelContext* const context, + const std::vector& allowed_handlers, // Empty means all handlers. + const OpInputList& partition_ids_list, const OpInputList& gains_list, + const OpInputList& splits_list, std::map* best_splits) { // Find best split per partition going through every feature candidate. // TODO(salehay): Is this worth parallelizing? for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { + if (!allowed_handlers.empty()) { + if (!std::binary_search(allowed_handlers.begin(), + allowed_handlers.end(), handler_id)) { + continue; + } + } const auto& partition_ids = partition_ids_list[handler_id].vec(); const auto& gains = gains_list[handler_id].vec(); const auto& splits = splits_list[handler_id].vec(); @@ -592,8 +618,10 @@ class GrowTreeEnsembleOp : public OpKernel { // Helper method to split a tree node and append its respective // leaf children given the split candidate. - void SplitTreeNode(const int32 node_id, SplitCandidate* split, - boosted_trees::trees::DecisionTreeConfig* tree_config) { + void SplitTreeNode( + const int32 node_id, SplitCandidate* split, + boosted_trees::trees::DecisionTreeConfig* tree_config, + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) { // No-op if we have no real node. CHECK(node_id < tree_config->nodes_size()) << "Invalid node " << node_id << " to split."; @@ -633,6 +661,9 @@ class GrowTreeEnsembleOp : public OpKernel { // Replace node in tree. (*tree_config->mutable_nodes(node_id)) = *split->split_info.mutable_split_node(); + if (learner_config_.constraints().max_number_of_unique_feature_columns()) { + ensemble_resource->MaybeAddUsedHandler(split->handler_id); + } } void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) { diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 72e20aaa127cda592bd314786cddb925cc87a075..7df514cd207c5e781f3b4abaa2020016b197669d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -436,7 +436,7 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, quantized_feature = quantile_ops.quantiles([float_column], [], [quantile_buckets], [], []) quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) - quantized_feature = array_ops.squeeze(quantized_feature) + quantized_feature = array_ops.squeeze(quantized_feature, axis=0) return (example_partition_ids, quantized_feature, gradients, hessians) def not_ready_inputs_fn(): @@ -468,7 +468,7 @@ def sparse_make_stats_update( [sparse_column_indices]) quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64) - quantized_feature = array_ops.squeeze(quantized_feature) + quantized_feature = array_ops.squeeze(quantized_feature, axis=0) example_indices, _ = array_ops.split( sparse_column_indices, num_or_size_splits=2, axis=1) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index ee16a5f838a65f20db4436eb86527518621b6d8d..54d03018d9e266beabbbabd78ebbb80cfe689c04 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -1121,6 +1121,87 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testDegenerativeCase(self): + with self.test_session() as sess: + # One data example only, one leaf and thus one quantile bucket.The same + # situation is when all examples have the same values. This case was + # causing before a failure. + gradients = array_ops.constant([0.2]) + hessians = array_ops.constant([0.12]) + example_partitions = array_ops.constant([1], dtype=dtypes.int32) + indices = array_ops.constant([[0, 0]], dtype=dtypes.int64) + values = array_ops.constant([0.58]) + sparse_column = sparse_tensor.SparseTensor(indices, values, [1, 1]) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = array_ops.ones([1, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(1, 2, class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. + self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([1], partitions) + self.assertAllEqual([0.0], gains) + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.58, split_node.split.threshold) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc index 0d46565a1962b88cbb267f3d6043610758790578..1297aa884938f2f099a32568acc80c6cd8162651 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc @@ -51,7 +51,7 @@ class IndicesRowIterator return tmp; } - reference operator*() { return iter_->ix()(row_idx_, 0); } + reference operator*() const { return iter_->ix()(row_idx_, 0); } pointer operator->() { return &iter_->ix()(row_idx_, 0); } @@ -97,7 +97,7 @@ class IndicesRowIterator } bool operator<(const IndicesRowIterator& other) const { - return (row_idx_ < other.row_idx_); + return (row_idx_ < other.row_idx_); } bool operator==(const IndicesRowIterator& other) const { diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto index 919e7cd81427c27cf892bc77998f52406d2bcf15..d84ba7438e7f03685d5bafca52ff8283f0fce898 100644 --- a/tensorflow/contrib/boosted_trees/proto/learner.proto +++ b/tensorflow/contrib/boosted_trees/proto/learner.proto @@ -22,6 +22,10 @@ message TreeConstraintsConfig { // Min hessian weight per node. float min_node_weight = 2; + + // Maximum number of unique features used in the tree. Zero means there is no + // limit. + int64 max_number_of_unique_feature_columns = 3; } // LearningRateConfig describes all supported learning rate tuners. diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index fc570c1083d01a65760a456c109dad93afd9f62a..4407c4d981785a279b6296f4726a221cacb4c5b1 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -128,6 +128,10 @@ message GrowingMetadata { // Number of layers that we have attempted to build. After pruning, these // layers might have been removed. int64 num_layers_attempted = 2; + + // Sorted list of column handlers that have been used in at least one split + // so far. + repeated int64 used_handler_ids = 3; } // DecisionTreeEnsembleConfig describes an ensemble of decision trees. diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index c2e65b643df90e88aadb0bb9acaf692da35b1a16..8ca1aabacaf53b66aaba184962922294427d6803 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -63,7 +63,7 @@ def _gen_learner_config(num_classes, if dropout_prob_of_skipping is not None: config.learning_rate_tuner.dropout.dropout_prob_of_skipping = ( dropout_prob_of_skipping) - return config.SerializeToString() + return config def _gen_dense_split_info(fc, threshold, left_weight, right_weight): @@ -145,7 +145,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here. - dropout_probability=0.5) + dropout_probability=0.5).SerializeToString() # Center bias for the initial step. grads = constant_op.constant([0.4, -0.3]) @@ -296,7 +296,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here, tree is not finalized. - dropout_probability=0.5) + dropout_probability=0.5).SerializeToString() # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -443,7 +443,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here - tree is not finalized. - dropout_probability=0.5) + dropout_probability=0.5).SerializeToString() # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -632,7 +632,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -772,7 +773,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. # All handlers have negative gain. @@ -837,7 +839,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -943,7 +946,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. # All handlers have negative gain. @@ -1090,7 +1094,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. # Second handler has positive gain. @@ -1330,7 +1335,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER, # Dropout will have no effect, since the tree will not be fully grown. - dropout_probability=1.0) + dropout_probability=1.0).SerializeToString() # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -1538,7 +1543,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, - dropout_probability=1.0) + dropout_probability=1.0).SerializeToString() # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -1583,6 +1588,301 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual( 2, tree_ensemble_config.tree_metadata[2].num_tree_weight_updates) + def testGrowExistingEnsembleTreeWithFeatureSelectionCanStillGrow(self): + """Test growing a tree with feature selection.""" + with self.test_session() as session: + # Create existing ensemble with one root split and one bias tree. + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge(""" + trees { + nodes { + leaf { + vector { + value: -0.32 + value: 0.28 + } + } + } + } + trees { + nodes { + categorical_id_binary_split { + feature_column: 3 + feature_id: 7 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 1.3 + } + } + nodes { + leaf { + sparse_vector { + index: 0 + value: 2.3 + } + } + } + nodes { + leaf { + sparse_vector { + index: 0 + value: -0.9 + } + } + } + } + tree_weights: 0.7 + tree_weights: 1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 5 + num_layers_grown: 1 + is_finalized: true + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 2 + used_handler_ids: 2 + used_handler_ids: 5 + } + """, tree_ensemble_config) + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=1, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + # There are 2 handler_ids in used_handler_ids already but one of them + # is handler 2, so we can still grow trees. + learner_config.constraints.max_number_of_unique_feature_columns = 2 + learner_config = learner_config.SerializeToString() + # Prepare handler inputs. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([7.62], dtype=np.float32) + handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] + handler2_partitions = np.array([0], dtype=np.int32) + handler2_gains = np.array([0.63], dtype=np.float32) + handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] + handler3_partitions = np.array([0], dtype=np.int32) + handler3_gains = np.array([7.62], dtype=np.float32) + handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] + + # Grow tree ensemble. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=1, + partition_ids=[ + handler1_partitions, handler2_partitions, handler3_partitions + ], + gains=[handler1_gains, handler2_gains, handler3_gains], + splits=[handler1_split, handler2_split, handler3_split], + learner_config=learner_config, + dropout_seed=123, + center_bias=True) + session.run(grow_op) + + # Expect a new tree to be added with the split from handler 1. + _, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + tree_ensemble_config.ParseFromString(serialized) + self.assertEqual(3, len(tree_ensemble_config.trees)) + self.assertEqual( + 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) + + def testGrowExistingEnsembleTreeWithFeatureSelectionEmptyEnsemble(self): + """Test growing a tree with feature selection with empty ensemble.""" + with self.test_session() as session: + # Create existing ensemble with one root split and one bias tree. + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=1, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + learner_config.constraints.max_number_of_unique_feature_columns = 2 + learner_config = learner_config.SerializeToString() + # Prepare handler inputs. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([7.62], dtype=np.float32) + handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] + handler2_partitions = np.array([0], dtype=np.int32) + handler2_gains = np.array([0.63], dtype=np.float32) + handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] + handler3_partitions = np.array([0], dtype=np.int32) + handler3_gains = np.array([7.62], dtype=np.float32) + handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] + + # Grow tree ensemble. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=1, + partition_ids=[ + handler1_partitions, handler2_partitions, handler3_partitions + ], + gains=[handler1_gains, handler2_gains, handler3_gains], + splits=[handler1_split, handler2_split, handler3_split], + learner_config=learner_config, + dropout_seed=123, + center_bias=True) + session.run(grow_op) + + _, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + tree_ensemble_config.ParseFromString(serialized) + self.assertEqual(1, len(tree_ensemble_config.trees)) + self.assertEqual( + 1, len(tree_ensemble_config.growing_metadata.used_handler_ids)) + + def testGrowExistingEnsembleTreeWithFeatureSelectionCantGrow(self): + """Test growing a tree with feature selection with empty ensemble.""" + with self.test_session() as session: + # Create existing ensemble with one root split and one bias tree. + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge(""" + trees { + nodes { + leaf { + vector { + value: -0.32 + value: 0.28 + } + } + } + } + trees { + nodes { + categorical_id_binary_split { + feature_column: 3 + feature_id: 7 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 1.3 + } + } + nodes { + leaf { + sparse_vector { + index: 0 + value: 2.3 + } + } + } + nodes { + leaf { + sparse_vector { + index: 0 + value: -0.9 + } + } + } + } + tree_weights: 0.7 + tree_weights: 1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 5 + num_layers_grown: 1 + is_finalized: true + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 2 + used_handler_ids: 4 + used_handler_ids: 5 + } + """, tree_ensemble_config) + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=1, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + learner_config.constraints.max_number_of_unique_feature_columns = 2 + learner_config = learner_config.SerializeToString() + # Prepare handler inputs. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([7.62], dtype=np.float32) + handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] + handler2_partitions = np.array([0], dtype=np.int32) + handler2_gains = np.array([0.63], dtype=np.float32) + handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] + handler3_partitions = np.array([0], dtype=np.int32) + handler3_gains = np.array([7.62], dtype=np.float32) + handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] + + # Grow tree ensemble. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=1, + partition_ids=[ + handler1_partitions, handler2_partitions, handler3_partitions + ], + gains=[handler1_gains, handler2_gains, handler3_gains], + splits=[handler1_split, handler2_split, handler3_split], + learner_config=learner_config, + dropout_seed=123, + center_bias=True) + session.run(grow_op) + + _, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + tree_ensemble_config.ParseFromString(serialized) + # We can't grow a tree since we have reached the limit of 2 unique + # features [4, 5] and the only available splits are from + # handlers [0, 1, 2]. + self.assertEqual(2, len(tree_ensemble_config.trees)) + self.assertEqual( + 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 7e8e15e7d8c89d1adaa472b1da7e8bb3c73ca17e..294e04002adac62fc123a3242a05a1b36f422433 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -45,6 +45,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token, epsilon, num_quantiles, + max_elements=None, name=None, container=None): """Creates a QuantileAccumulator object. @@ -53,6 +54,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token: The initial value for the stamp token. epsilon: Error bound on the quantile computation. num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. name: the name to save the accumulator under. container: An optional `string`. Defaults to `""` """ @@ -67,6 +69,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): self._quantile_accumulator_handle, init_stamp_token, epsilon=epsilon, + max_elements=max_elements, num_quantiles=num_quantiles) is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( self._quantile_accumulator_handle) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 6094dae6b59d8b05bb12a28cf167a536e6825287..b95956dae2a62b28643cd31815c5f5650eca337b 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -322,9 +322,11 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") (fc_names, dense_floats, sparse_float_indices, sparse_float_values, diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 16e24d97ddee0751e0b808b89080074c1b4baba7..dba51d4f527792d2a8dedc693f74c07119fd231d 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -912,8 +912,10 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEqual(1, len(output.trees[0].nodes[2].leaf.sparse_vector.index)) self.assertEqual(3, output.trees[0].nodes[2].leaf.sparse_vector.index[0]) - self.assertAlmostEqual( - 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0]) + self.assertAllClose( + 0.893284678459, + output.trees[0].nodes[2].leaf.sparse_vector.value[0], + atol=1e-4, rtol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 284ad5cdb9abf374650940ade7bb36663d72c0dd..ad9c8961aaadbc4c1ff6bdc7793171d0ad48d75f 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -111,6 +111,35 @@ class DecisionTreeEnsembleResource : public StampedResource { return decision_tree_ensemble_->tree_weights(index); } + void MaybeAddUsedHandler(const int32 handler_id) { + protobuf::RepeatedField* used_ids = + decision_tree_ensemble_->mutable_growing_metadata() + ->mutable_used_handler_ids(); + protobuf::RepeatedField::iterator first = + std::lower_bound(used_ids->begin(), used_ids->end(), handler_id); + if (first == used_ids->end()) { + used_ids->Add(handler_id); + return; + } + if (handler_id == *first) { + // It is a duplicate entry. + return; + } + used_ids->Add(handler_id); + std::rotate(first, used_ids->end() - 1, used_ids->end()); + } + + std::vector GetUsedHandlers() const { + std::vector result; + result.reserve( + decision_tree_ensemble_->growing_metadata().used_handler_ids().size()); + for (int64 h : + decision_tree_ensemble_->growing_metadata().used_handler_ids()) { + result.push_back(h); + } + return result; + } + // Sets the weight of i'th tree, and increment num_updates in tree_metadata. void SetTreeWeight(const int32 index, const float weight, const int32 increment_num_updates) { diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index aa8f5ed12bc6f779e3c1a923b9225ec283189747..fe8bd072afd43a64fa62a65bd8900b5a98dbe761 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -60,9 +60,7 @@ tf_py_test( size = "small", srcs = ["python/ops/bigquery_reader_ops_test.py"], additional_deps = [ - ":bigquery_reader_ops_op_lib", ":cloud_py", - "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc index 51821f6653550afd2d2e8a49b7337ff8ba0b5489..deb324634b6edc17c9725996115d80c5bd11cbde 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -202,22 +202,21 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { std::unique_ptr request(http_request_factory_->Create()); std::vector output_buffer; output_buffer.reserve(kBufferSize); - TF_RETURN_IF_ERROR(request->Init()); // The first time that we access BigQuery there is no page token. After that // we use the page token (which returns rows faster). if (!next_page_token_.empty()) { - TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( + request->SetUri(strings::StrCat( BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), - "&pageToken=", request->EscapeString(next_page_token_)))); + "&pageToken=", request->EscapeString(next_page_token_))); first_buffered_row_index_ += row_buffer_.size(); } else { - TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( + request->SetUri(strings::StrCat( BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), - "&startIndex=", first_buffered_row_index_))); + "&startIndex=", first_buffered_row_index_)); } - TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); - TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + request->AddAuthBearerHeader(auth_token); + request->SetResultBuffer(&output_buffer); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ", FullTableName()); @@ -293,10 +292,9 @@ Status BigQueryTableAccessor::ReadSchema() { std::unique_ptr request(http_request_factory_->Create()); std::vector output_buffer; output_buffer.reserve(kBufferSize); - TF_RETURN_IF_ERROR(request->Init()); - TF_RETURN_IF_ERROR(request->SetUri(BigQueryUriPrefix())); - TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); - TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + request->SetUri(BigQueryUriPrefix()); + request->AddAuthBearerHeader(auth_token); + request->SetResultBuffer(&output_buffer); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ", FullTableName()); diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index c74da9cabd6816bc9c7891e32937534cff2d677d..2e75ac226ea74e879edda5e03dff3d53c8a76569 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -18,6 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + +from six.moves.urllib.request import Request +from six.moves.urllib.request import urlopen + from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver from tensorflow.python.training.server_lib import ClusterSpec @@ -38,10 +42,16 @@ class TPUClusterResolver(ClusterResolver): Cloud Platform project. """ + def _requestComputeMetadata(self, path): + req = Request('http://metadata/computeMetadata/v1/%s' % path, + headers={'Metadata-Flavor': 'Google'}) + resp = urlopen(req) + return resp.read() + def __init__(self, - project, - zone, tpu_names, + zone=None, + project=None, job_name='tpu_worker', credentials='default', service=None): @@ -51,9 +61,13 @@ class TPUClusterResolver(ClusterResolver): for the IP addresses and ports of each Cloud TPU listed. Args: - project: Name of the GCP project containing Cloud TPUs - zone: Zone where the TPUs are located tpu_names: A list of names of the target Cloud TPUs. + zone: Zone where the TPUs are located. If omitted or empty, we will assume + that the zone of the TPU is the same as the zone of the GCE VM, which we + will try to discover from the GCE metadata service. + project: Name of the GCP project containing Cloud TPUs. If omitted or + empty, we will try to discover the project name of the GCE VM from the + GCE metadata service. job_name: Name of the TensorFlow job the TPUs belong to. credentials: GCE Credentials. If None, then we use default credentials from the oauth2client @@ -65,6 +79,13 @@ class TPUClusterResolver(ClusterResolver): ImportError: If the googleapiclient is not installed. """ + if not project: + project = self._requestComputeMetadata('/project/project-id') + + if not zone: + zone_path = self._requestComputeMetadata('/instance/zone') + zone = zone_path.split('/')[-1] + self._project = project self._zone = zone self._tpu_names = tpu_names diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index db7419be06b58e1c5737f69f2c7fd9fee44b9d95..0c4730613af4ad9ca87deb6200ab4bb93d3f6a53 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -48,6 +48,15 @@ class MockNodeClass(object): return MockRequestClass(name, self._tpu_map) +def mock_request_compute_metadata(cls, *args, **kwargs): + del cls, kwargs # Unused. + if args[0] == '/project/project-id': + return 'test-project' + elif args[0] == '/instance/zone': + return 'projects/test-project/locations/us-central1-c' + return '' + + class TPUClusterResolverTest(test.TestCase): def _verifyClusterSpecEquality(self, cluster_spec, expected_proto): @@ -89,6 +98,30 @@ class TPUClusterResolverTest(test.TestCase): return mock_client + @mock.patch.object(TPUClusterResolver, + '_requestComputeMetadata', + mock_request_compute_metadata) + def testRetrieveProjectAndZoneFromMetadata(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu_names=['test-tpu-1'], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index ba708673b0d562f928230f427406147ab22f0007..817e96f5da0e7512a9fd99cc9a4b4c6025d7dd68 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,6 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) -option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) @@ -34,6 +33,13 @@ option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) +option(tensorflow_DISABLE_EIGEN_FORCEINLINE "Disable forceinline, to speed up build on windows." OFF) + +# GPU, CUDA and cuDNN options +option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) +set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") +set(tensorflow_CUDNN_VERSION "7" CACHE STRING "cuDNN version to build against") + if(HAIKU) option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) else() @@ -53,7 +59,15 @@ if (NOT WIN32) set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) option(tensorflow_PATH_CUDNN_STATIC_LIB "Override PATH_STATIC_LIB for libcudnn_static.a" ${tensorflow_PATH_STATIC_LIB}) + if (NOT tensorflow_PATH_CUDNN_STATIC_LIB) + # option's default value is OFF. Fill it with real default values + set (tensorflow_PATH_CUDNN_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) + endif (NOT tensorflow_PATH_CUDNN_STATIC_LIB) option(tensorflow_PATH_NCCL_STATIC_LIB "Override PATH_STATIC_LIB for libnccl_static.a" ${tensorflow_PATH_STATIC_LIB}) + if (NOT tensorflow_PATH_NCCL_STATIC_LIB) + # option's default value is OFF. Fill it with real default values + set (tensorflow_PATH_NCCL_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) + endif (NOT tensorflow_PATH_NCCL_STATIC_LIB) option(tensorflow_CUDA_LIBRARY_PATH "Designate the default CUDA library paths" /usr/local/cuda/lib64) if (NOT tensorflow_CUDA_LIBRARY_PATH) # option's default value is OFF. Fill it with real default values @@ -92,6 +106,13 @@ else() set(CMAKE_POSITION_INDEPENDENT_CODE OFF) endif() +# TODO(jart): We should make this only apply to snapfn.cc +add_definitions(-DSQLITE_OMIT_LOAD_EXTENSION) + +if (tensorflow_DISABLE_EIGEN_FORCEINLINE) + add_definitions(-DEIGEN_STRONG_INLINE=inline) +endif() + add_definitions(-DEIGEN_AVOID_STL_ARRAY) if(WIN32) add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC) @@ -160,7 +181,6 @@ include(protobuf) include(re2) include(cub) include(sqlite) -include(double_conversion) if (tensorflow_BUILD_CC_TESTS) include(googletest) endif() @@ -179,7 +199,6 @@ set(tensorflow_EXTERNAL_LIBRARIES ${protobuf_STATIC_LIBRARIES} ${re2_STATIC_LIBRARIES} ${sqlite_STATIC_LIBRARIES} - ${double_conversion_STATIC_LIBRARIES} ) set(tensorflow_EXTERNAL_DEPENDENCIES zlib_copy_headers_to_destination @@ -198,7 +217,6 @@ set(tensorflow_EXTERNAL_DEPENDENCIES fft2d re2 sqlite_copy_headers_to_destination - double_conversion ) include_directories( @@ -221,7 +239,6 @@ include_directories( ${PROTOBUF_INCLUDE_DIRS} ${re2_INCLUDE_DIR} ${sqlite_INCLUDE_DIR} - ${double_conversion_INCLUDE_DIR} ) if(tensorflow_ENABLE_SSL_SUPPORT) @@ -266,7 +283,7 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - find_package(CUDA 8.0 REQUIRED) + find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED) # by default we assume compute cabability 3.5 and 5.2. If you change this change it in # CUDA_NVCC_FLAGS and cuda_config.h below @@ -320,13 +337,16 @@ if (tensorflow_ENABLE_GPU) ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) endif (WIN32) + # Remove "." from CUDA version variable. + string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) + # create cuda_config.h FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" - "#define TF_CUDA_VERSION \"64_80\"\n" - "#define TF_CUDNN_VERSION \"64_6\"\n" + "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" + "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -364,15 +384,15 @@ if (tensorflow_ENABLE_GPU) if(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value msvcp_dll_name=msvcp140.dll - cudart_dll_name=cudart64_80.dll - cuda_version_number=8.0 + cudart_dll_name=cudart64_${short_CUDA_VER}.dll + cuda_version_number=${tensorflow_CUDA_VERSION} nvcuda_dll_name=nvcuda.dll - cudnn_dll_name=cudnn64_6.dll - cudnn_version_number=6) + cudnn_dll_name=cudnn64_${tensorflow_CUDNN_VERSION}.dll + cudnn_version_number=${tensorflow_CUDNN_VERSION}) else(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value - cuda_version_number=8.0 - cudnn_version_number=6) + cuda_version_number=${tensorflow_CUDA_VERSION} + cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value @@ -387,10 +407,8 @@ endif() # Let's get to work! include(tf_core_framework.cmake) -# NOTE: Disabled until issue #3996 is fixed. -# include(tf_stream_executor.cmake) if (tensorflow_ENABLE_GPU) - include(tf_stream_executor.cmake) + include(tf_stream_executor.cmake) endif() include(tf_core_cpu.cmake) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 4ddfec5960d2b759bacb376202cd8dab6ef2b024..8f85a75ee466dbac524a1266dc2522109ca77cd5 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -19,23 +19,6 @@ for instructions on how to install a pre-built TensorFlow package on Windows. ### Current known limitations * It is not possible to load a custom Op library. * GCS file system is not supported. -* The following Ops are not currently implemented: - - Dequantize - - QuantizeAndDequantize - - QuantizedAvgPool - - QuantizedBatchNomWithGlobalNormalization - - QuantizedBiasAdd - - QuantizedConcat - - QuantizedConv2D - - QuantizedMatmul - - QuantizedMaxPoo - - QuantizeDownAndShrinkRange - - QuantizedRelu - - QuantizedRelu6 - - QuantizedReshape - - QuantizeV2 - - RequantizationRange - - Requantize ## Building with CMake @@ -47,7 +30,7 @@ bindings. * CMake version 3.5 or later. -* [Git](http://git-scm.com) +* [Git](https://git-scm.com) * [SWIG](http://www.swig.org/download.html) @@ -65,7 +48,7 @@ bindings. * Microsoft Windows 10 - Microsoft Visual Studio Enterprise 2015 with Visual C++ 2015 - - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.continuum.io/downloads) + - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) - [swigwin-3.0.10](http://www.swig.org/download.html) - [NVidia CUDA Toolkit 8.0](https://developer.nvidia.com/cuda-downloads) diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake index cca8444e2ae9952ea7c69a9392580ead715d363b..5ad477fdff68feab4adf0c0072c68c8e55390ab8 100644 --- a/tensorflow/contrib/cmake/external/boringssl.cmake +++ b/tensorflow/contrib/cmake/external/boringssl.cmake @@ -39,11 +39,7 @@ ExternalProject_Add(boringssl # BUILD_IN_SOURCE 1 INSTALL_COMMAND "" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF ) diff --git a/tensorflow/contrib/cmake/external/double_conversion.cmake b/tensorflow/contrib/cmake/external/double_conversion.cmake deleted file mode 100644 index 527ccdc8d887cb4c2e7d2412c99a8bc682568472..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cmake/external/double_conversion.cmake +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -include (ExternalProject) - -set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion) -set(double_conversion_URL https://github.com/google/double-conversion.git) -set(double_conversion_TAG 5664746) -set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR}) -set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so) -set(double_conversion_INCLUDES ${double_conversion_BUILD}) - -if(WIN32) - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib) -else() - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a) -endif() - -set(double_conversion_HEADERS - "${double_conversion_INCLUDE_DIR}/double-conversion/bignum-dtoa.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/cached-powers.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/double-conversion.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/fixed-dtoa.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/strtod.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/bignum.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/diy-fp.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/fast-dtoa.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/ieee.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/utils.h" -) - -ExternalProject_Add(double_conversion - PREFIX double_conversion - GIT_REPOSITORY ${double_conversion_URL} - GIT_TAG ${double_conversion_TAG} - DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - BUILD_IN_SOURCE 1 - INSTALL_COMMAND "" - CMAKE_CACHE_ARGS - -DCMAKE_BUILD_TYPE:STRING=Release - -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -) diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake index 3b146657bfc9bdd54db14839195af45972e67aff..a235442dc5c0a07e249653381436eeae81575883 100644 --- a/tensorflow/contrib/cmake/external/gemmlowp.cmake +++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(gemmlowp_URL https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip) -set(gemmlowp_HASH SHA256=dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d) +set(gemmlowp_URL https://github.com/google/gemmlowp/archive/6a2a90822e8546fc2bfa7044de0faf1c1cb4862f.zip) +set(gemmlowp_HASH SHA256=3447948d219f3270383766bbe08942888c0eb4e0ca6663c0e0548502ec5bb77d) set(gemmlowp_BUILD ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) set(gemmlowp_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 41ea0b48a4600d7ca2dd2f4a61c14ec0cc5b4734..28adb4fe84423bb5a21c78dac4e757505ce87d1d 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 54e8f37e537794c2d814c1604c1282125f64f093) +set(GRPC_TAG 730b778632e79cc3c96ad237f282d687ee325ce7) if(WIN32) set(grpc_STATIC_LIBRARIES diff --git a/tensorflow/contrib/cmake/external/jsoncpp.cmake b/tensorflow/contrib/cmake/external/jsoncpp.cmake index d2ae4c76e8cd175cdc3ba41fdf4e4009f8237309..861201f97edbce2d9d70a833ce5a8cad46f2470a 100644 --- a/tensorflow/contrib/cmake/external/jsoncpp.cmake +++ b/tensorflow/contrib/cmake/external/jsoncpp.cmake @@ -42,11 +42,7 @@ ExternalProject_Add(jsoncpp BUILD_IN_SOURCE 1 INSTALL_COMMAND "" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF ) diff --git a/tensorflow/contrib/cmake/external/lmdb.cmake b/tensorflow/contrib/cmake/external/lmdb.cmake index e41384f023ca9fc4cba697917b491af5a9db92bc..41b314e2857577581eb27eb6c6480b757d0b436c 100644 --- a/tensorflow/contrib/cmake/external/lmdb.cmake +++ b/tensorflow/contrib/cmake/external/lmdb.cmake @@ -29,11 +29,7 @@ ExternalProject_Add(lmdb INSTALL_DIR ${lmdb_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${lmdb_INSTALL} diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 155c91cb97dbe5ef33c318efb5544a9fa22166c7..05080060479b6240edb8ab9f65160b3dd182feb9 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 93815892dddafe9146a5f7e7042281d59d0f4323) +set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index aad6618f52f909096fd2388e867ef3a965d033cb..b277be5690387b06876ca89eb88becbf885486a4 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -41,11 +41,7 @@ ExternalProject_Add(png INSTALL_DIR ${png_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${png_INSTALL} diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index b53857a47bfbf797af02fe7f69474263119161cd..aedb793d2aef4bf6950cd074cd065909667eaf75 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -44,11 +44,7 @@ ExternalProject_Add(protobuf ${PROTOBUF_ADDITIONAL_CMAKE_OPTIONS} INSTALL_COMMAND "" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DZLIB_ROOT:STRING=${ZLIB_INSTALL} diff --git a/tensorflow/contrib/cmake/external/re2.cmake b/tensorflow/contrib/cmake/external/re2.cmake index d10f5959f71dd350e6e2bcb81be8882b203fb231..371d8447f93735e7af2a5a2b16f128a47b5a082a 100644 --- a/tensorflow/contrib/cmake/external/re2.cmake +++ b/tensorflow/contrib/cmake/external/re2.cmake @@ -38,11 +38,7 @@ ExternalProject_Add(re2 BUILD_IN_SOURCE 1 DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX:STRING=${re2_INSTALL} -DRE2_BUILD_TESTING:BOOL=OFF diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake index 926c271fd9ea6e2a30251aa408bd49859ae95070..013b3a862f13fd9017fade500d391ecc2bd27fae 100644 --- a/tensorflow/contrib/cmake/external/snappy.cmake +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -40,11 +40,7 @@ ExternalProject_Add(snappy LOG_CONFIGURE ON LOG_BUILD ON CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DSNAPPY_BUILD_TESTS:BOOL=OFF diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 785039a46983747557607562675349c150e064ad..8297c60712c49ed6f47a9750691eee1325a5b55e 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -28,6 +28,7 @@ endif() set(sqlite_HEADERS "${sqlite_BUILD}/sqlite3.h" + "${sqlite_BUILD}/sqlite3ext.h" ) if (WIN32) @@ -53,11 +54,7 @@ else() INSTALL_DIR ${sqlite_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${sqlite_INSTALL} diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake index f10f84336e8b1c0a2c7de7ea1f8b8af7c21f8b51..5bec14fb00a50f6e6e8c7d8b703bde681e9d02ae 100644 --- a/tensorflow/contrib/cmake/external/zlib.cmake +++ b/tensorflow/contrib/cmake/external/zlib.cmake @@ -42,11 +42,7 @@ ExternalProject_Add(zlib BUILD_IN_SOURCE 1 DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL} ) diff --git a/tensorflow/tools/ci_build/install/install_cmake_for_clang.sh b/tensorflow/contrib/cmake/make.sh similarity index 83% rename from tensorflow/tools/ci_build/install/install_cmake_for_clang.sh rename to tensorflow/contrib/cmake/make.sh index 3e626a69ab5e6b7f8d1b4997b459301606501a8e..eed3c34aba1f0326ec741169a187eb2982f253a3 100755 --- a/tensorflow/tools/ci_build/install/install_cmake_for_clang.sh +++ b/tensorflow/contrib/cmake/make.sh @@ -14,6 +14,13 @@ # limitations under the License. # ============================================================================== -CMAKE_URL="https://cmake.org/files/v3.7/cmake-3.7.2-Linux-x86_64.tar.gz" +( +cd "$(dirname "$0")" +mkdir -p _build -wget -O - "${CMAKE_URL}" | tar xzf - -C /usr/local --strip-components=1 +( +cd _build +rm -rf -- * +cmake .. +) +) diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index 594c2492d4fd68b50c8493321a2c4dcc2d41917e..aaae18a313dd082b428654091c9411600c981ec9 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -158,12 +158,21 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) endif () endif () diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt new file mode 100644 index 0000000000000000000000000000000000000000..e37d059a84cb3d75cebf2473e7880f6d6cb20a69 --- /dev/null +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -0,0 +1,440 @@ +tensorflow +tensorflow/core +tensorflow/core/example +tensorflow/core/framework +tensorflow/core/lib +tensorflow/core/lib/core +tensorflow/core/protobuf +tensorflow/core/util +tensorflow/examples +tensorflow/examples/tutorials +tensorflow/examples/tutorials/mnist +tensorflow/python +tensorflow/python/client +tensorflow/python/data +tensorflow/python/data/ops +tensorflow/python/data/util +tensorflow/python/debug +tensorflow/python/debug/cli +tensorflow/python/debug/examples +tensorflow/python/debug/lib +tensorflow/python/debug/wrappers +tensorflow/python/eager +tensorflow/python/estimator +tensorflow/python/estimator/canned +tensorflow/python/estimator/export +tensorflow/python/estimator/inputs +tensorflow/python/estimator/inputs/queues +tensorflow/python/feature_column +tensorflow/python/framework +tensorflow/python/grappler +tensorflow/python/keras +tensorflow/python/keras/activations +tensorflow/python/keras/applications +tensorflow/python/keras/applications/inception_resnet_v2 +tensorflow/python/keras/applications/inception_v3 +tensorflow/python/keras/applications/mobilenet +tensorflow/python/keras/applications/resnet50 +tensorflow/python/keras/applications/vgg16 +tensorflow/python/keras/applications/vgg19 +tensorflow/python/keras/applications/xception +tensorflow/python/keras/backend +tensorflow/python/keras/callbacks +tensorflow/python/keras/constraints +tensorflow/python/keras/datasets +tensorflow/python/keras/datasets/boston_housing +tensorflow/python/keras/datasets/cifar10 +tensorflow/python/keras/datasets/cifar100 +tensorflow/python/keras/datasets/fashion_mnist +tensorflow/python/keras/datasets/imdb +tensorflow/python/keras/datasets/mnist +tensorflow/python/keras/datasets/reuters +tensorflow/python/keras/estimator +tensorflow/python/keras/initializers +tensorflow/python/keras/layers +tensorflow/python/keras/losses +tensorflow/python/keras/metrics +tensorflow/python/keras/models +tensorflow/python/keras/optimizers +tensorflow/python/keras/preprocessing +tensorflow/python/keras/preprocessing/image +tensorflow/python/keras/preprocessing/sequence +tensorflow/python/keras/preprocessing/text +tensorflow/python/keras/regularizers +tensorflow/python/keras/utils +tensorflow/python/keras/wrappers +tensorflow/python/keras/wrappers/scikit_learn +tensorflow/python/keras/_impl +tensorflow/python/keras/_impl/keras +tensorflow/python/keras/_impl/keras/applications +tensorflow/python/keras/_impl/keras/datasets +tensorflow/python/keras/_impl/keras/engine +tensorflow/python/keras/_impl/keras/layers +tensorflow/python/keras/_impl/keras/preprocessing +tensorflow/python/keras/_impl/keras/utils +tensorflow/python/keras/_impl/keras/wrappers +tensorflow/python/kernel_tests +tensorflow/python/kernel_tests/distributions +tensorflow/python/kernel_tests/linalg +tensorflow/python/kernel_tests/random +tensorflow/python/layers +tensorflow/python/lib +tensorflow/python/lib/core +tensorflow/python/lib/io +tensorflow/python/ops +tensorflow/python/ops/distributions +tensorflow/python/ops/linalg +tensorflow/python/ops/losses +tensorflow/python/platform +tensorflow/python/profiler +tensorflow/python/profiler/internal +tensorflow/python/saved_model +tensorflow/python/summary +tensorflow/python/summary/writer +tensorflow/python/tools +tensorflow/python/training +tensorflow/python/user_ops +tensorflow/python/util +tensorflow/python/util/protobuf +tensorflow/tools +tensorflow/tools/graph_transforms +tensorflow/contrib +tensorflow/contrib/all_reduce +tensorflow/contrib/all_reduce/python +tensorflow/contrib/android +tensorflow/contrib/android/java +tensorflow/contrib/android/java/org +tensorflow/contrib/android/java/org/tensorflow +tensorflow/contrib/android/java/org/tensorflow/contrib +tensorflow/contrib/android/java/org/tensorflow/contrib/android +tensorflow/contrib/android/jni +tensorflow/contrib/batching +tensorflow/contrib/batching/kernels +tensorflow/contrib/batching/python +tensorflow/contrib/batching/python/ops +tensorflow/contrib/bayesflow +tensorflow/contrib/bayesflow/python +tensorflow/contrib/bayesflow/python/ops +tensorflow/contrib/boosted_trees +tensorflow/contrib/boosted_trees/estimator_batch +tensorflow/contrib/boosted_trees/kernels +tensorflow/contrib/boosted_trees/ops +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/boosted_trees/python +tensorflow/contrib/boosted_trees/python/ops +tensorflow/contrib/cloud +tensorflow/contrib/cloud/kernels +tensorflow/contrib/cloud/ops +tensorflow/contrib/cloud/python +tensorflow/contrib/cloud/python/ops +tensorflow/contrib/cluster_resolver +tensorflow/contrib/cluster_resolver/python +tensorflow/contrib/cluster_resolver/python/training +tensorflow/contrib/coder +tensorflow/contrib/coder/kernels +tensorflow/contrib/coder/ops +tensorflow/contrib/coder/python +tensorflow/contrib/coder/python/ops +tensorflow/contrib/compiler +tensorflow/contrib/copy_graph +tensorflow/contrib/copy_graph/python +tensorflow/contrib/copy_graph/python/util +tensorflow/contrib/crf +tensorflow/contrib/crf/python +tensorflow/contrib/crf/python/ops +tensorflow/contrib/cudnn_rnn +tensorflow/contrib/cudnn_rnn/kernels +tensorflow/contrib/cudnn_rnn/ops +tensorflow/contrib/cudnn_rnn/python +tensorflow/contrib/cudnn_rnn/python/layers +tensorflow/contrib/cudnn_rnn/python/ops +tensorflow/contrib/data +tensorflow/contrib/data/kernels +tensorflow/contrib/data/python +tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/ops +tensorflow/contrib/decision_trees +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/deprecated +tensorflow/contrib/distributions +tensorflow/contrib/distributions/python +tensorflow/contrib/distributions/python/ops +tensorflow/contrib/distributions/python/ops/bijectors +tensorflow/contrib/eager +tensorflow/contrib/eager/python +tensorflow/contrib/estimator +tensorflow/contrib/estimator/python +tensorflow/contrib/estimator/python/estimator +tensorflow/contrib/factorization +tensorflow/contrib/factorization/examples +tensorflow/contrib/factorization/kernels +tensorflow/contrib/factorization/ops +tensorflow/contrib/factorization/python +tensorflow/contrib/factorization/python/ops +tensorflow/contrib/ffmpeg +tensorflow/contrib/ffmpeg/default +tensorflow/contrib/framework +tensorflow/contrib/framework/kernels +tensorflow/contrib/framework/ops +tensorflow/contrib/framework/python +tensorflow/contrib/framework/python/framework +tensorflow/contrib/framework/python/ops +tensorflow/contrib/fused_conv +tensorflow/contrib/fused_conv/kernels +tensorflow/contrib/fused_conv/python +tensorflow/contrib/fused_conv/python/ops +tensorflow/contrib/gan +tensorflow/contrib/gan/python +tensorflow/contrib/gan/python/estimator +tensorflow/contrib/gan/python/estimator/python +tensorflow/contrib/gan/python/eval +tensorflow/contrib/gan/python/eval/python +tensorflow/contrib/gan/python/features +tensorflow/contrib/gan/python/features/python +tensorflow/contrib/gan/python/losses +tensorflow/contrib/gan/python/losses/python +tensorflow/contrib/graph_editor +tensorflow/contrib/graph_editor/examples +tensorflow/contrib/grid_rnn +tensorflow/contrib/grid_rnn/python +tensorflow/contrib/grid_rnn/python/ops +tensorflow/contrib/hooks +tensorflow/contrib/hooks/python +tensorflow/contrib/image +tensorflow/contrib/image/kernels +tensorflow/contrib/image/ops +tensorflow/contrib/image/python +tensorflow/contrib/image/python/ops +tensorflow/contrib/input_pipeline +tensorflow/contrib/input_pipeline/kernels +tensorflow/contrib/input_pipeline/ops +tensorflow/contrib/input_pipeline/python +tensorflow/contrib/input_pipeline/python/ops +tensorflow/contrib/integrate +tensorflow/contrib/integrate/python +tensorflow/contrib/integrate/python/ops +tensorflow/contrib/keras +tensorflow/contrib/keras/api +tensorflow/contrib/keras/api/keras +tensorflow/contrib/keras/api/keras/activations +tensorflow/contrib/keras/api/keras/applications +tensorflow/contrib/keras/api/keras/applications/inception_v3 +tensorflow/contrib/keras/api/keras/applications/mobilenet +tensorflow/contrib/keras/api/keras/applications/resnet50 +tensorflow/contrib/keras/api/keras/applications/vgg16 +tensorflow/contrib/keras/api/keras/applications/vgg19 +tensorflow/contrib/keras/api/keras/applications/xception +tensorflow/contrib/keras/api/keras/backend +tensorflow/contrib/keras/api/keras/callbacks +tensorflow/contrib/keras/api/keras/constraints +tensorflow/contrib/keras/api/keras/datasets +tensorflow/contrib/keras/api/keras/datasets/boston_housing +tensorflow/contrib/keras/api/keras/datasets/cifar10 +tensorflow/contrib/keras/api/keras/datasets/cifar100 +tensorflow/contrib/keras/api/keras/datasets/imdb +tensorflow/contrib/keras/api/keras/datasets/mnist +tensorflow/contrib/keras/api/keras/datasets/reuters +tensorflow/contrib/keras/api/keras/initializers +tensorflow/contrib/keras/api/keras/layers +tensorflow/contrib/keras/api/keras/losses +tensorflow/contrib/keras/api/keras/metrics +tensorflow/contrib/keras/api/keras/models +tensorflow/contrib/keras/api/keras/optimizers +tensorflow/contrib/keras/api/keras/preprocessing +tensorflow/contrib/keras/api/keras/preprocessing/image +tensorflow/contrib/keras/api/keras/preprocessing/sequence +tensorflow/contrib/keras/api/keras/preprocessing/text +tensorflow/contrib/keras/api/keras/regularizers +tensorflow/contrib/keras/api/keras/utils +tensorflow/contrib/keras/api/keras/wrappers +tensorflow/contrib/keras/api/keras/wrappers/scikit_learn +tensorflow/contrib/kernel_methods +tensorflow/contrib/kernel_methods/python +tensorflow/contrib/kernel_methods/python/mappers +tensorflow/contrib/kfac +tensorflow/contrib/kfac/examples +tensorflow/contrib/kfac/python +tensorflow/contrib/kfac/python/ops +tensorflow/contrib/labeled_tensor +tensorflow/contrib/labeled_tensor/python +tensorflow/contrib/labeled_tensor/python/ops +tensorflow/contrib/layers +tensorflow/contrib/layers/kernels +tensorflow/contrib/layers/ops +tensorflow/contrib/layers/python +tensorflow/contrib/layers/python/layers +tensorflow/contrib/layers/python/ops +tensorflow/contrib/learn +tensorflow/contrib/learn/python +tensorflow/contrib/learn/python/learn +tensorflow/contrib/learn/python/learn/datasets +tensorflow/contrib/learn/python/learn/datasets/data +tensorflow/contrib/learn/python/learn/estimators +tensorflow/contrib/learn/python/learn/learn_io +tensorflow/contrib/learn/python/learn/ops +tensorflow/contrib/learn/python/learn/preprocessing +tensorflow/contrib/learn/python/learn/utils +tensorflow/contrib/legacy_seq2seq +tensorflow/contrib/legacy_seq2seq/python +tensorflow/contrib/legacy_seq2seq/python/ops +tensorflow/contrib/libsvm +tensorflow/contrib/libsvm/python +tensorflow/contrib/libsvm/python/kernel_tests +tensorflow/contrib/libsvm/python/ops +tensorflow/contrib/linalg +tensorflow/contrib/linalg/python +tensorflow/contrib/linalg/python/ops +tensorflow/contrib/linear_optimizer +tensorflow/contrib/linear_optimizer/kernels +tensorflow/contrib/linear_optimizer/kernels/g3doc +tensorflow/contrib/linear_optimizer/python +tensorflow/contrib/linear_optimizer/python/ops +# TODO(drpngx): Fix failing imports +# tensorflow/contrib/lite/python +# tensorflow/contrib/lite/toco/python +tensorflow/contrib/lookup +tensorflow/contrib/losses +tensorflow/contrib/losses/python +tensorflow/contrib/losses/python/losses +tensorflow/contrib/losses/python/metric_learning +tensorflow/contrib/makefile +tensorflow/contrib/memory_stats +tensorflow/contrib/memory_stats/kernels +tensorflow/contrib/memory_stats/ops +tensorflow/contrib/memory_stats/python +tensorflow/contrib/memory_stats/python/ops +tensorflow/contrib/meta_graph_transform +tensorflow/contrib/metrics +tensorflow/contrib/metrics/python +tensorflow/contrib/metrics/python/metrics +tensorflow/contrib/metrics/python/ops +tensorflow/contrib/model_pruning +tensorflow/contrib/model_pruning/examples +tensorflow/contrib/model_pruning/examples/cifar10 +tensorflow/contrib/model_pruning/python +tensorflow/contrib/model_pruning/python/layers +tensorflow/contrib/nccl +tensorflow/contrib/nccl/kernels +tensorflow/contrib/nccl/ops +tensorflow/contrib/nccl/python +tensorflow/contrib/nccl/python/ops +tensorflow/contrib/ndlstm +tensorflow/contrib/ndlstm/python +tensorflow/contrib/nearest_neighbor/kernels +tensorflow/contrib/nearest_neighbor/ops +tensorflow/contrib/nearest_neighbor/python +tensorflow/contrib/nearest_neighbor/python/ops +tensorflow/contrib/nn +tensorflow/contrib/nn/python +tensorflow/contrib/nn/python/ops +tensorflow/contrib/opt +tensorflow/contrib/opt/python +tensorflow/contrib/opt/python/training +tensorflow/contrib/pi_examples +tensorflow/contrib/pi_examples/camera +tensorflow/contrib/pi_examples/label_image +tensorflow/contrib/pi_examples/label_image/data +tensorflow/contrib/periodic_resample +tensorflow/contrib/periodic_resample/python +tensorflow/contrib/periodic_resample/python/ops +tensorflow/contrib/predictor +tensorflow/contrib/quantization +tensorflow/contrib/quantization/python +tensorflow/contrib/quantize +tensorflow/contrib/quantize/python +tensorflow/contrib/receptive_field +tensorflow/contrib/receptive_field/python +tensorflow/contrib/receptive_field/python/util +tensorflow/contrib/receptive_field/python/util/examples +tensorflow/contrib/reduce_slice_ops +tensorflow/contrib/reduce_slice_ops/kernels +tensorflow/contrib/reduce_slice_ops/ops +tensorflow/contrib/reduce_slice_ops/python +tensorflow/contrib/reduce_slice_ops/python/ops +tensorflow/contrib/remote_fused_graph/pylib +tensorflow/contrib/remote_fused_graph/pylib/python +tensorflow/contrib/remote_fused_graph/pylib/python/ops +tensorflow/contrib/resampler +tensorflow/contrib/resampler/kernels +tensorflow/contrib/resampler/ops +tensorflow/contrib/resampler/python +tensorflow/contrib/resampler/python/ops +tensorflow/contrib/rnn +tensorflow/contrib/rnn/kernels +tensorflow/contrib/rnn/ops +tensorflow/contrib/rnn/python +tensorflow/contrib/rnn/python/kernel_tests +tensorflow/contrib/rnn/python/ops +tensorflow/contrib/saved_model +tensorflow/contrib/saved_model/python +tensorflow/contrib/saved_model/python/saved_model +tensorflow/contrib/seq2seq +tensorflow/contrib/seq2seq/kernels +tensorflow/contrib/seq2seq/ops +tensorflow/contrib/seq2seq/python +tensorflow/contrib/seq2seq/python/ops +tensorflow/contrib/session_bundle +tensorflow/contrib/session_bundle/example +tensorflow/contrib/signal +tensorflow/contrib/signal/python +tensorflow/contrib/signal/python/ops +tensorflow/contrib/slim +tensorflow/contrib/slim/python +tensorflow/contrib/slim/python/slim +tensorflow/contrib/slim/python/slim/data +tensorflow/contrib/slim/python/slim/nets +tensorflow/contrib/solvers +tensorflow/contrib/solvers/python +tensorflow/contrib/solvers/python/ops +tensorflow/contrib/sparsemax +tensorflow/contrib/sparsemax/python +tensorflow/contrib/sparsemax/python/ops +tensorflow/contrib/specs +tensorflow/contrib/specs/python +tensorflow/contrib/staging +tensorflow/contrib/stat_summarizer +tensorflow/contrib/stat_summarizer/python +tensorflow/contrib/stateless +tensorflow/contrib/stateless/python +tensorflow/contrib/summary +tensorflow/contrib/tensorboard +tensorflow/contrib/tensorboard/plugins +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensor_forest +tensorflow/contrib/tensor_forest/client +tensorflow/contrib/tensor_forest/hybrid +tensorflow/contrib/tensor_forest/hybrid/core +tensorflow/contrib/tensor_forest/hybrid/core/ops +tensorflow/contrib/tensor_forest/hybrid/python +tensorflow/contrib/tensor_forest/hybrid/python/layers +tensorflow/contrib/tensor_forest/hybrid/python/models +tensorflow/contrib/tensor_forest/hybrid/python/ops +tensorflow/contrib/tensor_forest/kernels +tensorflow/contrib/tensor_forest/python +tensorflow/contrib/tensor_forest/python/ops +tensorflow/contrib/testing +tensorflow/contrib/testing/python +tensorflow/contrib/testing/python/framework +tensorflow/contrib/text +tensorflow/contrib/text/kernels +tensorflow/contrib/text/ops +tensorflow/contrib/text/python +tensorflow/contrib/text/python/ops +tensorflow/contrib/tfprof +tensorflow/contrib/timeseries +tensorflow/contrib/timeseries/examples +tensorflow/contrib/timeseries/examples/data +tensorflow/contrib/timeseries/python +tensorflow/contrib/timeseries/python/timeseries +tensorflow/contrib/timeseries/python/timeseries/state_space_models +tensorflow/contrib/tpu +tensorflow/contrib/tpu/ops +tensorflow/contrib/tpu/profiler +tensorflow/contrib/tpu/python +tensorflow/contrib/tpu/python/ops +tensorflow/contrib/tpu/python/profiler +tensorflow/contrib/tpu/python/tpu +tensorflow/contrib/training +tensorflow/contrib/training/python +tensorflow/contrib/training/python/training +tensorflow/contrib/util diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt new file mode 100644 index 0000000000000000000000000000000000000000..8a9c406d8b118c10ddcaafb0e4fc242aa79cdb57 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -0,0 +1,19 @@ +tensorflow/core +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/cloud/kernels +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/gdr +tensorflow/contrib/lite/toco +tensorflow/contrib/mpi +tensorflow/contrib/mpi_collectives +tensorflow/contrib/session_bundle +tensorflow/contrib/tensor_forest/proto +tensorflow/contrib/tensorboard/graph_explorer/proto +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensorboard/plugins/trace +tensorflow/contrib/tpu/proto +tensorflow/contrib/tpu/profiler +tensorflow/contrib/training/python/training +tensorflow/contrib/verbs diff --git a/tensorflow/contrib/cmake/python_protos_cc.txt b/tensorflow/contrib/cmake/python_protos_cc.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4a257b25c814a1464308d0e6ce3ce65d21f6a36 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos_cc.txt @@ -0,0 +1,5 @@ +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/session_bundle +tensorflow/contrib/tensorboard +tensorflow/contrib/training diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index 6e2ac203f9a7f96cb14752a91483840a9eb6b451..f3cf3e70441de67ef79bc9cedf85549315170c29 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -83,7 +83,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names}) ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.cc - COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} ${tensorflow_source_dir}/tensorflow/core/api_def/base_api + COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${cc_ops_include_internal} ${tensorflow_source_dir}/tensorflow/core/api_def/base_api DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir ) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 5c01ca382fb9cc7a01a6f2b60a510c59f0aa7119..e4213ea2a47da2a7381cccd0504235ad62018d4e 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -63,7 +63,7 @@ if (tensorflow_ENABLE_GPU) file(GLOB_RECURSE tf_core_gpu_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu/cupti_wrapper.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc" + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.h" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.cc" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index c607546f4a5244fb6e7cd12db874f07a962f6f4d..24d7fb82a268623be06c2b98b5857b6b9b95c3a1 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -191,10 +191,6 @@ file(GLOB_RECURSE tf_core_lib_srcs "${tensorflow_source_dir}/tensorflow/core/lib/*.h" "${tensorflow_source_dir}/tensorflow/core/lib/*.cc" "${tensorflow_source_dir}/tensorflow/core/public/*.h" - # TODO(@jart): Move StatusOr into core. - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor_internals.h" ) file(GLOB tf_core_platform_srcs @@ -211,7 +207,7 @@ if (NOT tensorflow_ENABLE_GPU) list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) else() file(GLOB tf_core_platform_srcs_exclude - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc") + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() @@ -317,8 +313,15 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/util/*main.cc" "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/loader.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/vacuum.cc" ) +# TODO(jart): Why doesn't this work? +# set_source_files_properties( +# ${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/snapfn.cc +# PROPERTIES COMPILE_FLAGS -DSQLITE_OMIT_LOAD_EXTENSION) + list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) add_library(tf_core_framework OBJECT diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 2d015908a890fd7757bf212573f4ebce8ba8b30d..6927bf03f08b68a1f13f6a0978af629af45575e8 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -63,6 +63,10 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" @@ -79,12 +83,15 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/bipartite_match_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/segmentation_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/ops/distort_image_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/libsvm/ops/libsvm_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" @@ -150,9 +157,6 @@ list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs}) if(WIN32) file(GLOB_RECURSE tf_core_kernels_windows_exclude_srcs # not working on windows yet - "${tensorflow_source_dir}/tensorflow/core/kernels/meta_support.*" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.h" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/neon/*" # not in core - those are loaded dynamically as dll "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index e8c2cd347327843d10d13c1d24a800ff776aa8c1..6f56e9d0869bc0d3311ffbc68326f8ab43758019 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -26,6 +26,7 @@ set(tf_op_lib_names "image_ops" "io_ops" "linalg_ops" + "list_ops" "lookup_ops" "logging_ops" "math_ops" @@ -80,6 +81,7 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_training "${tensorflow_source_dir}/ten GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_core_profiler.cmake b/tensorflow/contrib/cmake/tf_core_profiler.cmake index 61ed6a1e145299125d037b48b8b644cae1ce96e7..b91a7f43e5c03e933d10572e54e0c8c914c55f71 100644 --- a/tensorflow/contrib/cmake/tf_core_profiler.cmake +++ b/tensorflow/contrib/cmake/tf_core_profiler.cmake @@ -17,6 +17,8 @@ ######################################################## file(GLOB_RECURSE tf_core_profiler_srcs "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" + "${tensorflow_source_dir}/tensorflow/core/profiler/tfprof_options.h" + "${tensorflow_source_dir}/tensorflow/core/profiler/tfprof_options.cc" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/*.h" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/*.cc" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/advisor/*.h" diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 5e15a972d6272151e128c37dfe398225e3b4f44e..17bbdb1a86f4a1b026b6d159a7b8adad9a3d1f57 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -120,33 +120,44 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() -file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/*.proto" - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/decision_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos.txt python_protos) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos "${python_protos}") +STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}") + +foreach(python_proto ${python_protos}) + if(NOT python_proto MATCHES "\#") + if(NOT EXISTS "${tensorflow_source_dir}/${python_proto}") + message(SEND_ERROR "Python proto directory not found: ${python_proto}") + endif() + file(GLOB_RECURSE tf_python_protos_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto}/*.proto" + ) + list(APPEND tf_python_protos_srcs ${tf_python_protos_src}) + endif() +endforeach(python_proto) + RELATIVE_PROTOBUF_GENERATE_PYTHON( - ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs} + ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_python_protos_srcs} ) -# NOTE(mrry): Avoid regenerating the tensorflow/core protos because this -# can cause benign-but-failing-on-Windows-due-to-file-locking conflicts -# when two rules attempt to generate the same file. -file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos_cc.txt python_protos_cc) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos_cc "${python_protos_cc}") +STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}") + +foreach(python_proto_cc ${python_protos_cc}) + if(NOT python_proto_cc MATCHES "\#") + if(NOT EXISTS "${tensorflow_source_dir}/${python_proto_cc}") + message(SEND_ERROR "Python proto CC directory not found: ${python_proto_cc}") + endif() + file(GLOB_RECURSE tf_python_protos_cc_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto_cc}/*.proto" + ) + list(APPEND tf_python_protos_cc_srcs ${tf_python_protos_cc_src}) + endif() +endforeach(python_proto_cc) + RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS ${tensorflow_source_dir} ${tf_python_protos_cc_srcs} ) @@ -192,315 +203,20 @@ function(add_python_module MODULE_NAME) endif() endfunction() -add_python_module("tensorflow") -add_python_module("tensorflow/core") -add_python_module("tensorflow/core/example") -add_python_module("tensorflow/core/framework") -add_python_module("tensorflow/core/lib") -add_python_module("tensorflow/core/lib/core") -add_python_module("tensorflow/core/protobuf") -add_python_module("tensorflow/core/util") -add_python_module("tensorflow/examples") -add_python_module("tensorflow/examples/tutorials") -add_python_module("tensorflow/examples/tutorials/mnist") -add_python_module("tensorflow/python") -add_python_module("tensorflow/python/client") -add_python_module("tensorflow/python/data") -add_python_module("tensorflow/python/data/ops") -add_python_module("tensorflow/python/data/util") -add_python_module("tensorflow/python/debug") -add_python_module("tensorflow/python/debug/cli") -add_python_module("tensorflow/python/debug/examples") -add_python_module("tensorflow/python/debug/lib") -add_python_module("tensorflow/python/debug/wrappers") -add_python_module("tensorflow/python/eager") -add_python_module("tensorflow/python/estimator") -add_python_module("tensorflow/python/estimator/canned") -add_python_module("tensorflow/python/estimator/export") -add_python_module("tensorflow/python/estimator/inputs") -add_python_module("tensorflow/python/estimator/inputs/queues") -add_python_module("tensorflow/python/feature_column") -add_python_module("tensorflow/python/framework") -add_python_module("tensorflow/python/grappler") -add_python_module("tensorflow/python/keras") -add_python_module("tensorflow/python/keras/activations") -add_python_module("tensorflow/python/keras/applications") -add_python_module("tensorflow/python/keras/applications/inception_resnet_v2") -add_python_module("tensorflow/python/keras/applications/inception_v3") -add_python_module("tensorflow/python/keras/applications/mobilenet") -add_python_module("tensorflow/python/keras/applications/resnet50") -add_python_module("tensorflow/python/keras/applications/vgg16") -add_python_module("tensorflow/python/keras/applications/vgg19") -add_python_module("tensorflow/python/keras/applications/xception") -add_python_module("tensorflow/python/keras/backend") -add_python_module("tensorflow/python/keras/callbacks") -add_python_module("tensorflow/python/keras/constraints") -add_python_module("tensorflow/python/keras/datasets") -add_python_module("tensorflow/python/keras/datasets/boston_housing") -add_python_module("tensorflow/python/keras/datasets/cifar10") -add_python_module("tensorflow/python/keras/datasets/cifar100") -add_python_module("tensorflow/python/keras/datasets/fashion_mnist") -add_python_module("tensorflow/python/keras/datasets/imdb") -add_python_module("tensorflow/python/keras/datasets/mnist") -add_python_module("tensorflow/python/keras/datasets/reuters") -add_python_module("tensorflow/python/keras/estimator") -add_python_module("tensorflow/python/keras/initializers") -add_python_module("tensorflow/python/keras/layers") -add_python_module("tensorflow/python/keras/losses") -add_python_module("tensorflow/python/keras/metrics") -add_python_module("tensorflow/python/keras/models") -add_python_module("tensorflow/python/keras/optimizers") -add_python_module("tensorflow/python/keras/preprocessing") -add_python_module("tensorflow/python/keras/preprocessing/image") -add_python_module("tensorflow/python/keras/preprocessing/sequence") -add_python_module("tensorflow/python/keras/preprocessing/text") -add_python_module("tensorflow/python/keras/regularizers") -add_python_module("tensorflow/python/keras/utils") -add_python_module("tensorflow/python/keras/wrappers") -add_python_module("tensorflow/python/keras/wrappers/scikit_learn") -add_python_module("tensorflow/python/keras/_impl") -add_python_module("tensorflow/python/keras/_impl/keras") -add_python_module("tensorflow/python/keras/_impl/keras/applications") -add_python_module("tensorflow/python/keras/_impl/keras/datasets") -add_python_module("tensorflow/python/keras/_impl/keras/engine") -add_python_module("tensorflow/python/keras/_impl/keras/layers") -add_python_module("tensorflow/python/keras/_impl/keras/preprocessing") -add_python_module("tensorflow/python/keras/_impl/keras/utils") -add_python_module("tensorflow/python/keras/_impl/keras/wrappers") -add_python_module("tensorflow/python/kernel_tests") -add_python_module("tensorflow/python/kernel_tests/distributions") -add_python_module("tensorflow/python/kernel_tests/linalg") -add_python_module("tensorflow/python/layers") -add_python_module("tensorflow/python/lib") -add_python_module("tensorflow/python/lib/core") -add_python_module("tensorflow/python/lib/io") -add_python_module("tensorflow/python/ops") -add_python_module("tensorflow/python/ops/distributions") -add_python_module("tensorflow/python/ops/linalg") -add_python_module("tensorflow/python/ops/losses") -add_python_module("tensorflow/python/platform") -add_python_module("tensorflow/python/platform/default") -add_python_module("tensorflow/python/platform/summary") -add_python_module("tensorflow/python/profiler/") -add_python_module("tensorflow/python/profiler/internal") -add_python_module("tensorflow/python/saved_model") -add_python_module("tensorflow/python/summary") -add_python_module("tensorflow/python/summary/writer") -add_python_module("tensorflow/python/tools") -add_python_module("tensorflow/python/training") -add_python_module("tensorflow/python/user_ops") -add_python_module("tensorflow/python/util") -add_python_module("tensorflow/python/util/protobuf") -add_python_module("tensorflow/tools") -add_python_module("tensorflow/tools/graph_transforms") -add_python_module("tensorflow/contrib") -add_python_module("tensorflow/contrib/all_reduce") -add_python_module("tensorflow/contrib/all_reduce/python") -add_python_module("tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/java") -add_python_module("tensorflow/contrib/android/java/org") -add_python_module("tensorflow/contrib/android/java/org/tensorflow") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/jni") -add_python_module("tensorflow/contrib/bayesflow") -add_python_module("tensorflow/contrib/bayesflow/examples") -add_python_module("tensorflow/contrib/bayesflow/examples/reinforce_simple") -add_python_module("tensorflow/contrib/bayesflow/python") -add_python_module("tensorflow/contrib/bayesflow/python/kernel_tests") -add_python_module("tensorflow/contrib/bayesflow/python/ops") -add_python_module("tensorflow/contrib/boosted_trees") -add_python_module("tensorflow/contrib/boosted_trees/estimator_batch") -add_python_module("tensorflow/contrib/boosted_trees/ops") -add_python_module("tensorflow/contrib/boosted_trees/proto") -add_python_module("tensorflow/contrib/boosted_trees/python") -add_python_module("tensorflow/contrib/boosted_trees/python/kernel_tests") -add_python_module("tensorflow/contrib/boosted_trees/python/ops") -add_python_module("tensorflow/contrib/cloud") -add_python_module("tensorflow/contrib/cloud/kernels") -add_python_module("tensorflow/contrib/cloud/ops") -add_python_module("tensorflow/contrib/cloud/python") -add_python_module("tensorflow/contrib/cloud/python/ops") -add_python_module("tensorflow/contrib/cluster_resolver") -add_python_module("tensorflow/contrib/cluster_resolver/python") -add_python_module("tensorflow/contrib/cluster_resolver/python/training") -add_python_module("tensorflow/contrib/compiler") -add_python_module("tensorflow/contrib/copy_graph") -add_python_module("tensorflow/contrib/copy_graph/python") -add_python_module("tensorflow/contrib/copy_graph/python/util") -add_python_module("tensorflow/contrib/crf") -add_python_module("tensorflow/contrib/crf/python") -add_python_module("tensorflow/contrib/crf/python/kernel_tests") -add_python_module("tensorflow/contrib/crf/python/ops") -add_python_module("tensorflow/contrib/cudnn_rnn") -add_python_module("tensorflow/contrib/cudnn_rnn/kernels") -add_python_module("tensorflow/contrib/cudnn_rnn/ops") -add_python_module("tensorflow/contrib/cudnn_rnn/python") -add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/cudnn_rnn/python/layers") -add_python_module("tensorflow/contrib/cudnn_rnn/python/ops") -add_python_module("tensorflow/contrib/data") -add_python_module("tensorflow/contrib/data/python") -add_python_module("tensorflow/contrib/data/python/kernel_tests") -add_python_module("tensorflow/contrib/data/python/ops") -add_python_module("tensorflow/contrib/decision_trees") -add_python_module("tensorflow/contrib/decision_trees/proto") -add_python_module("tensorflow/contrib/deprecated") -add_python_module("tensorflow/contrib/distributions") -add_python_module("tensorflow/contrib/distributions/python") -add_python_module("tensorflow/contrib/distributions/python/kernel_tests") -add_python_module("tensorflow/contrib/distributions/python/ops") -add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") -add_python_module("tensorflow/contrib/eager") -add_python_module("tensorflow/contrib/eager/python") -add_python_module("tensorflow/contrib/estimator") -add_python_module("tensorflow/contrib/estimator/python") -add_python_module("tensorflow/contrib/estimator/python/estimator") -add_python_module("tensorflow/contrib/factorization") -add_python_module("tensorflow/contrib/factorization/examples") -add_python_module("tensorflow/contrib/factorization/kernels") -add_python_module("tensorflow/contrib/factorization/ops") -add_python_module("tensorflow/contrib/factorization/python") -add_python_module("tensorflow/contrib/factorization/python/kernel_tests") -add_python_module("tensorflow/contrib/factorization/python/ops") -add_python_module("tensorflow/contrib/ffmpeg") -add_python_module("tensorflow/contrib/ffmpeg/default") -add_python_module("tensorflow/contrib/ffmpeg/testdata") -add_python_module("tensorflow/contrib/framework") -add_python_module("tensorflow/contrib/framework/kernels") -add_python_module("tensorflow/contrib/framework/ops") -add_python_module("tensorflow/contrib/framework/python") -add_python_module("tensorflow/contrib/framework/python/framework") -add_python_module("tensorflow/contrib/framework/python/ops") -add_python_module("tensorflow/contrib/gan") -add_python_module("tensorflow/contrib/gan/python") -add_python_module("tensorflow/contrib/gan/python/eval") -add_python_module("tensorflow/contrib/gan/python/eval/python") -add_python_module("tensorflow/contrib/gan/python/features") -add_python_module("tensorflow/contrib/gan/python/features/python") -add_python_module("tensorflow/contrib/gan/python/estimator") -add_python_module("tensorflow/contrib/gan/python/estimator/python") -add_python_module("tensorflow/contrib/gan/python/losses") -add_python_module("tensorflow/contrib/gan/python/losses/python") -add_python_module("tensorflow/contrib/graph_editor") -add_python_module("tensorflow/contrib/graph_editor/examples") -add_python_module("tensorflow/contrib/graph_editor/tests") -add_python_module("tensorflow/contrib/grid_rnn") -add_python_module("tensorflow/contrib/grid_rnn/python") -add_python_module("tensorflow/contrib/grid_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/grid_rnn/python/ops") -add_python_module("tensorflow/contrib/hooks") -add_python_module("tensorflow/contrib/image") -add_python_module("tensorflow/contrib/image/ops") -add_python_module("tensorflow/contrib/image/python") -add_python_module("tensorflow/contrib/image/python/ops") -add_python_module("tensorflow/contrib/input_pipeline") -add_python_module("tensorflow/contrib/input_pipeline/ops") -add_python_module("tensorflow/contrib/input_pipeline/python") -add_python_module("tensorflow/contrib/input_pipeline/python/ops") -add_python_module("tensorflow/contrib/integrate") -add_python_module("tensorflow/contrib/integrate/python") -add_python_module("tensorflow/contrib/integrate/python/ops") -add_python_module("tensorflow/contrib/ios_examples") -add_python_module("tensorflow/contrib/ios_examples/benchmark") -add_python_module("tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/benchmark/data") -add_python_module("tensorflow/contrib/ios_examples/camera") -add_python_module("tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/camera/en.lproj") -add_python_module("tensorflow/contrib/ios_examples/simple") -add_python_module("tensorflow/contrib/ios_examples/simple/data") -add_python_module("tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj") -add_python_module("tensorflow/contrib/keras") -add_python_module("tensorflow/contrib/keras/api") -add_python_module("tensorflow/contrib/keras/api/keras") -add_python_module("tensorflow/contrib/keras/api/keras/activations") -add_python_module("tensorflow/contrib/keras/api/keras/applications") -add_python_module("tensorflow/contrib/keras/api/keras/applications/inception_v3") -add_python_module("tensorflow/contrib/keras/api/keras/applications/mobilenet") -add_python_module("tensorflow/contrib/keras/api/keras/applications/resnet50") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg16") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg19") -add_python_module("tensorflow/contrib/keras/api/keras/applications/xception") -add_python_module("tensorflow/contrib/keras/api/keras/backend") -add_python_module("tensorflow/contrib/keras/api/keras/callbacks") -add_python_module("tensorflow/contrib/keras/api/keras/constraints") -add_python_module("tensorflow/contrib/keras/api/keras/datasets") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/boston_housing") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar10") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar100") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/imdb") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/mnist") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/reuters") -add_python_module("tensorflow/contrib/keras/api/keras/initializers") -add_python_module("tensorflow/contrib/keras/api/keras/layers") -add_python_module("tensorflow/contrib/keras/api/keras/losses") -add_python_module("tensorflow/contrib/keras/api/keras/metrics") -add_python_module("tensorflow/contrib/keras/api/keras/models") -add_python_module("tensorflow/contrib/keras/api/keras/optimizers") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/image") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/sequence") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/text") -add_python_module("tensorflow/contrib/keras/api/keras/regularizers") -add_python_module("tensorflow/contrib/keras/api/keras/utils") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers/scikit_learn") -add_python_module("tensorflow/contrib/keras/python") -add_python_module("tensorflow/contrib/keras/python/keras") -add_python_module("tensorflow/contrib/keras/python/keras/applications") -add_python_module("tensorflow/contrib/keras/python/keras/datasets") -add_python_module("tensorflow/contrib/keras/python/keras/engine") -add_python_module("tensorflow/contrib/keras/python/keras/layers") -add_python_module("tensorflow/contrib/keras/python/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/python/keras/utils") -add_python_module("tensorflow/contrib/keras/python/keras/wrappers") -add_python_module("tensorflow/contrib/kernel_methods") -add_python_module("tensorflow/contrib/kernel_methods/python") -add_python_module("tensorflow/contrib/kernel_methods/python/mappers") -add_python_module("tensorflow/contrib/kfac") -add_python_module("tensorflow/contrib/kfac/examples") -add_python_module("tensorflow/contrib/kfac/python") -add_python_module("tensorflow/contrib/kfac/python/ops") -add_python_module("tensorflow/contrib/labeled_tensor") -add_python_module("tensorflow/contrib/labeled_tensor/python") -add_python_module("tensorflow/contrib/labeled_tensor/python/ops") -add_python_module("tensorflow/contrib/layers") -add_python_module("tensorflow/contrib/layers/kernels") -add_python_module("tensorflow/contrib/layers/ops") -add_python_module("tensorflow/contrib/layers/python") -add_python_module("tensorflow/contrib/layers/python/kernel_tests") -add_python_module("tensorflow/contrib/layers/python/layers") -add_python_module("tensorflow/contrib/layers/python/ops") -add_python_module("tensorflow/contrib/learn") -add_python_module("tensorflow/contrib/learn/python") -add_python_module("tensorflow/contrib/learn/python/learn") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/queues") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/transforms") -add_python_module("tensorflow/contrib/learn/python/learn/datasets") -add_python_module("tensorflow/contrib/learn/python/learn/datasets/data") -add_python_module("tensorflow/contrib/learn/python/learn/estimators") -add_python_module("tensorflow/contrib/learn/python/learn/learn_io") -add_python_module("tensorflow/contrib/learn/python/learn/ops") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/utils") -add_python_module("tensorflow/contrib/legacy_seq2seq") -add_python_module("tensorflow/contrib/legacy_seq2seq/python") -add_python_module("tensorflow/contrib/legacy_seq2seq/python/ops") -add_python_module("tensorflow/contrib/linalg") -add_python_module("tensorflow/contrib/linalg/python") -add_python_module("tensorflow/contrib/linalg/python/ops") -add_python_module("tensorflow/contrib/linalg/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer") -add_python_module("tensorflow/contrib/linear_optimizer/kernels") -add_python_module("tensorflow/contrib/linear_optimizer/kernels/g3doc") -add_python_module("tensorflow/contrib/linear_optimizer/python") -add_python_module("tensorflow/contrib/linear_optimizer/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer/python/ops") +FILE(READ python_modules.txt python_modules) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_modules "${python_modules}") +STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}") + +foreach(python_module ${python_modules}) + if(NOT python_module MATCHES "\#") + if(NOT EXISTS "${tensorflow_source_dir}/${python_module}") + message(SEND_ERROR "Python module not found: ${python_module}") + endif() + add_python_module(${python_module}) + endif() +endforeach(python_module) + add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") @@ -514,161 +230,6 @@ add_custom_command( TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) -add_python_module("tensorflow/contrib/lookup") -add_python_module("tensorflow/contrib/losses") -add_python_module("tensorflow/contrib/losses/python") -add_python_module("tensorflow/contrib/losses/python/losses") -add_python_module("tensorflow/contrib/losses/python/metric_learning") -add_python_module("tensorflow/contrib/makefile") -add_python_module("tensorflow/contrib/makefile/test") -add_python_module("tensorflow/contrib/memory_stats") -add_python_module("tensorflow/contrib/memory_stats/kernels") -add_python_module("tensorflow/contrib/memory_stats/ops") -add_python_module("tensorflow/contrib/memory_stats/python") -add_python_module("tensorflow/contrib/memory_stats/python/kernel_tests") -add_python_module("tensorflow/contrib/memory_stats/python/ops") -add_python_module("tensorflow/contrib/meta_graph_transform") -add_python_module("tensorflow/contrib/metrics") -add_python_module("tensorflow/contrib/metrics/kernels") -add_python_module("tensorflow/contrib/metrics/ops") -add_python_module("tensorflow/contrib/metrics/python") -add_python_module("tensorflow/contrib/metrics/python/kernel_tests") -add_python_module("tensorflow/contrib/metrics/python/metrics") -add_python_module("tensorflow/contrib/metrics/python/ops") -add_python_module("tensorflow/contrib/model_pruning") -add_python_module("tensorflow/contrib/model_pruning/examples") -add_python_module("tensorflow/contrib/model_pruning/examples/cifar10") -add_python_module("tensorflow/contrib/model_pruning/python") -add_python_module("tensorflow/contrib/model_pruning/python/layers") -add_python_module("tensorflow/contrib/ndlstm") -add_python_module("tensorflow/contrib/ndlstm/python") -add_python_module("tensorflow/contrib/nn") -add_python_module("tensorflow/contrib/nn/python") -add_python_module("tensorflow/contrib/nn/python/ops") -add_python_module("tensorflow/contrib/nccl") -add_python_module("tensorflow/contrib/nccl/kernels") -add_python_module("tensorflow/contrib/nccl/ops") -add_python_module("tensorflow/contrib/nccl/python") -add_python_module("tensorflow/contrib/nccl/python/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/kernels") -add_python_module("tensorflow/contrib/nearest_neighbor/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/python") -add_python_module("tensorflow/contrib/nearest_neighbor/python/kernel_tests") -add_python_module("tensorflow/contrib/nearest_neighbor/python/ops") -add_python_module("tensorflow/contrib/opt") -add_python_module("tensorflow/contrib/opt/python") -add_python_module("tensorflow/contrib/opt/python/training") -add_python_module("tensorflow/contrib/pi_examples") -add_python_module("tensorflow/contrib/pi_examples/camera") -add_python_module("tensorflow/contrib/pi_examples/label_image") -add_python_module("tensorflow/contrib/pi_examples/label_image/data") -add_python_module("tensorflow/contrib/periodic_resample") -add_python_module("tensorflow/contrib/periodic_resample/python") -add_python_module("tensorflow/contrib/periodic_resample/python/ops") -add_python_module("tensorflow/contrib/periodic_resample/python/kernel_tests") -add_python_module("tensorflow/contrib/predictor") -add_python_module("tensorflow/contrib/quantization") -add_python_module("tensorflow/contrib/quantization/python") -add_python_module("tensorflow/contrib/quantize") -add_python_module("tensorflow/contrib/quantize/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python/ops") -add_python_module("tensorflow/contrib/resampler") -add_python_module("tensorflow/contrib/resampler/kernels") -add_python_module("tensorflow/contrib/resampler/ops") -add_python_module("tensorflow/contrib/resampler/python") -add_python_module("tensorflow/contrib/resampler/python/ops") -add_python_module("tensorflow/contrib/rnn") -add_python_module("tensorflow/contrib/rnn/kernels") -add_python_module("tensorflow/contrib/rnn/ops") -add_python_module("tensorflow/contrib/rnn/python") -add_python_module("tensorflow/contrib/rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/rnn/python/ops") -add_python_module("tensorflow/contrib/saved_model") -add_python_module("tensorflow/contrib/saved_model/python") -add_python_module("tensorflow/contrib/saved_model/python/saved_model") -add_python_module("tensorflow/contrib/seq2seq") -add_python_module("tensorflow/contrib/seq2seq/kernels") -add_python_module("tensorflow/contrib/seq2seq/ops") -add_python_module("tensorflow/contrib/seq2seq/python") -add_python_module("tensorflow/contrib/seq2seq/python/kernel_tests") -add_python_module("tensorflow/contrib/seq2seq/python/ops") -add_python_module("tensorflow/contrib/session_bundle") -add_python_module("tensorflow/contrib/session_bundle/example") -add_python_module("tensorflow/contrib/session_bundle/testdata") -add_python_module("tensorflow/contrib/signal") -add_python_module("tensorflow/contrib/signal/python") -add_python_module("tensorflow/contrib/signal/python/ops") -add_python_module("tensorflow/contrib/slim") -add_python_module("tensorflow/contrib/slim/python") -add_python_module("tensorflow/contrib/slim/python/slim") -add_python_module("tensorflow/contrib/slim/python/slim/data") -add_python_module("tensorflow/contrib/slim/python/slim/nets") -add_python_module("tensorflow/contrib/solvers") -add_python_module("tensorflow/contrib/solvers/python") -add_python_module("tensorflow/contrib/solvers/python/ops") -add_python_module("tensorflow/contrib/sparsemax") -add_python_module("tensorflow/contrib/sparsemax/python") -add_python_module("tensorflow/contrib/sparsemax/python/ops") -add_python_module("tensorflow/contrib/specs") -add_python_module("tensorflow/contrib/specs/python") -add_python_module("tensorflow/contrib/staging") -add_python_module("tensorflow/contrib/stat_summarizer") -add_python_module("tensorflow/contrib/stateless") -add_python_module("tensorflow/contrib/tensorboard") -add_python_module("tensorflow/contrib/tensorboard/plugins") -add_python_module("tensorflow/contrib/tensorboard/plugins/projector") -add_python_module("tensorflow/contrib/tensor_forest") -add_python_module("tensorflow/contrib/tensor_forest/client") -add_python_module("tensorflow/contrib/tensor_forest/core") -add_python_module("tensorflow/contrib/tensor_forest/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/data") -add_python_module("tensorflow/contrib/tensor_forest/hybrid") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/layers") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/models") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/ops") -add_python_module("tensorflow/contrib/tensor_forest/python") -add_python_module("tensorflow/contrib/tensor_forest/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/python/ops") -add_python_module("tensorflow/contrib/testing") -add_python_module("tensorflow/contrib/testing/python") -add_python_module("tensorflow/contrib/testing/python/framework") -add_python_module("tensorflow/contrib/text") -add_python_module("tensorflow/contrib/text/kernels") -add_python_module("tensorflow/contrib/text/ops") -add_python_module("tensorflow/contrib/text/python") -add_python_module("tensorflow/contrib/text/python/ops") -add_python_module("tensorflow/contrib/tfprof") -add_python_module("tensorflow/contrib/timeseries") -add_python_module("tensorflow/contrib/timeseries/examples") -add_python_module("tensorflow/contrib/timeseries/examples/data") -add_python_module("tensorflow/contrib/timeseries/python") -add_python_module("tensorflow/contrib/timeseries/python/timeseries") -add_python_module("tensorflow/contrib/timeseries/python/timeseries/state_space_models") -add_python_module("tensorflow/contrib/tpu") -add_python_module("tensorflow/contrib/tpu/ops") -add_python_module("tensorflow/contrib/tpu/profiler") -add_python_module("tensorflow/contrib/tpu/python") -add_python_module("tensorflow/contrib/tpu/python/ops") -add_python_module("tensorflow/contrib/tpu/python/profiler") -add_python_module("tensorflow/contrib/tpu/python/tpu") -add_python_module("tensorflow/contrib/training") -add_python_module("tensorflow/contrib/training/python") -add_python_module("tensorflow/contrib/training/python/training") -add_python_module("tensorflow/contrib/util") -add_python_module("tensorflow/contrib/reduce_slice_ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/kernels") -add_python_module("tensorflow/contrib/reduce_slice_ops/ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/python") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops") -add_python_module("tensorflow/contrib/summary") # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") @@ -743,7 +304,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) # containing the wrappers. add_custom_command( OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION} - COMMAND ${tf_python_op_lib_name}_gen_python @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} + COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} DEPENDS ${tf_python_op_lib_name}_gen_python ) @@ -766,6 +327,7 @@ GENERATE_PYTHON_OP_LIB("dataset_ops") GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") +GENERATE_PYTHON_OP_LIB("list_ops") GENERATE_PYTHON_OP_LIB("logging_ops") GENERATE_PYTHON_OP_LIB("lookup_ops") GENERATE_PYTHON_OP_LIB("nn_ops") @@ -797,6 +359,8 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_quantiles_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_quantile_ops.py) GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_coder_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops" @@ -896,6 +460,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor.h" @@ -906,6 +472,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc" "${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h" diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 7884631d6d27bd5375b80d7eb5593d10d709e450..2e79eadf7f566690a7742757ceb56e147ebd6ea0 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -139,17 +139,21 @@ if (tensorflow_BUILD_PYTHON_TESTS) file(GLOB_RECURSE tf_test_src_py ${tf_test_rnn_src_py} + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/debug/cli/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_conv_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/platform/build_info_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/*_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/internal/*_test.py" "${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py" "${tensorflow_source_dir}/tensorflow/python/training/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/coder/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" @@ -187,6 +191,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/profiler/pprof_profiler_test.py" # flaky test "${tensorflow_source_dir}/tensorflow/python/profiler/internal/run_metadata_test.py" + "${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py" # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" # requires scipy @@ -217,16 +222,20 @@ if (tensorflow_BUILD_PYTHON_TESTS) # TFDBG grpc:// mode is not yet available on Windows. "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/source_remote_test.py" # stl on windows handles overflows different "${tensorflow_source_dir}/tensorflow/python/kernel_tests/as_string_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/clip_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/list_ops_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/tensor_array_ops_test.py" # Needs portpicker. # Numerical issues, calculations off. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py" "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/backend_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py" # Float division by zero "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" # Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces. @@ -235,11 +244,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py" # IteratorGetMax OutOfRangeError @@ -263,9 +272,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_grad_test.py" # cudaSolver handle creation fails. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops # Dataset tests - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" # Broken tensorboard test due to cmake issues. "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 @@ -296,6 +305,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Test should only be run manually "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py" + # Depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/control_flow_util_test.py" + # Flaky replicate_model_fn_test + "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py" # b/71901810 ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) @@ -363,7 +377,6 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/tensor_coding_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/graph_transferer_test.cc" - "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc" ) if (NOT tensorflow_ENABLE_GPU) diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ec3d550b70d2aaa23b989c44f3d86fa87cffb335 --- /dev/null +++ b/tensorflow/contrib/coder/BUILD @@ -0,0 +1,167 @@ +# Description: +# Contains entropy coding related modules. + +package(default_visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", +]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_py_test", +) + +cc_library( + name = "range_coder", + srcs = [ + "kernels/range_coder.cc", + ], + hdrs = [ + "kernels/range_coder.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "range_coder_test", + size = "small", + srcs = ["kernels/range_coder_test.cc"], + deps = [ + ":range_coder", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_gen_op_libs( + op_lib_names = ["coder_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_kernel_library( + name = "range_coder_ops", + srcs = [ + "kernels/range_coder_ops.cc", + "kernels/range_coder_ops_util.cc", + ], + hdrs = [ + "kernels/range_coder_ops_util.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":coder_ops_op_lib", + ":range_coder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "range_coder_ops_test", + size = "small", + srcs = ["kernels/range_coder_ops_test.cc"], + deps = [ + ":range_coder", + ":range_coder_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + ], +) + +cc_library( + name = "all_ops", + deps = [":coder_ops_op_lib"], +) + +cc_library( + name = "all_kernels", + deps = [":range_coder_ops"], +) + +tf_custom_op_library( + name = "python/ops/_coder_ops.so", + srcs = [ + "kernels/range_coder.cc", + "kernels/range_coder.h", + "kernels/range_coder_ops.cc", + "kernels/range_coder_ops_util.cc", + "kernels/range_coder_ops_util.h", + "ops/coder_ops.cc", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_coder_ops", + out = "python/ops/gen_coder_ops.py", + deps = [":coder_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "coder_ops_py", + srcs = [ + "__init__.py", + "python/ops/coder_ops.py", + ], + dso = [ + ":python/ops/_coder_ops.so", + ], + kernels = [ + ":all_kernels", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_coder_ops", + "//tensorflow/contrib/util:util_py", + ], +) + +tf_py_test( + name = "coder_ops_py_test", + srcs = [ + "python/ops/coder_ops_test.py", + ], + additional_deps = [ + ":coder_ops_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + ], + main = "python/ops/coder_ops_test.py", +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/contrib/coder/README.md b/tensorflow/contrib/coder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e1e867db5aa701eb73ee43a47cd3dcc2dc783a04 --- /dev/null +++ b/tensorflow/contrib/coder/README.md @@ -0,0 +1,73 @@ +# Entropy coder + +This module contains range encoder and range decoder which can encode integer +data into string with cumulative distribution functions (CDF). + +## Data and CDF values + +The data to be encoded should be non-negative integers in half-open interval +`[0, m)`. Then a CDF is represented as an integral vector of length `m + 1` +where `CDF(i) = f(Pr(X < i) * 2^precision)` for i = 0,1,...,m, and `precision` +is an attribute in range `0 < precision <= 16`. The function `f` maps real +values into integers, e.g., round or floor. It is important that to encode a +number `i`, `CDF(i + 1) - CDF(i)` cannot be zero. + +Note that we used `Pr(X < i)` not `Pr(X <= i)`, and therefore CDF(0) = 0 always. + +## RangeEncode: data shapes and CDF shapes + +For each data element, its CDF has to be provided. Therefore if the shape of CDF +should be `data.shape + (m + 1,)` in NumPy-like notation. For example, if `data` +is a 2-D tensor of shape (10, 10) and its elements are in `[0, 64)`, then the +CDF tensor should have shape (10, 10, 65). + +This may make CDF tensor too large, and in many applications all data elements +may have the same probability distribution. To handle this, `RangeEncode` +supports limited broadcasting CDF into data. Broadcasting is limited in the +following sense: + +- All CDF axes but the last one is broadcasted into data but not the other way + around, +- The number of CDF axes does not extend, i.e., `CDF.ndim == data.ndim + 1`. + +In the previous example where data has shape (10, 10), the followings are +acceptable CDF shapes: + +- (10, 10, 65) +- (1, 10, 65) +- (10, 1, 65) +- (1, 1, 65) + +## RangeDecode + +`RangeEncode` encodes neither data shape nor termination character. Therefore +the decoder should know how many characters are encoded into the string, and +`RangeDecode` takes the encoded data shape as the second argument. The same +shape restrictions as `RangeEncode` inputs apply here. + +## Example + +```python +data = tf.random_uniform((128, 128), 0, 10, dtype=tf.int32) + +histogram = tf.bincount(data, minlength=10, maxlength=10) +cdf = tf.cumsum(histogram, exclusive=False) +# CDF should have length m + 1. +cdf = tf.pad(cdf, [[1, 0]]) +# CDF axis count must be one more than data. +cdf = tf.reshape(cdf, [1, 1, -1]) + +# Note that data has 2^14 elements, and therefore the sum of CDF is 2^14. +data = tf.cast(data, tf.int16) +encoded = coder.range_encode(data, cdf, precision=14) +decoded = coder.range_decode(encoded, tf.shape(data), cdf, precision=14) + +# data and decoded should be the same. +sess = tf.Session() +x, y = sess.run((data, decoded)) +assert np.all(x == y) +``` + +## Authors +Sung Jin Hwang (github: [ssjhv](https://github.com/ssjhv)) and Nick Johnston +(github: [nmjohn](https://github.com/nmjohn)) diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7e663e6f1359f399cdaa80e037635a8f7546b37 --- /dev/null +++ b/tensorflow/contrib/coder/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Entropy code operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.coder.python.ops.coder_ops import * +# pylint: enable=wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/coder/kernels/range_coder.cc b/tensorflow/contrib/coder/kernels/range_coder.cc new file mode 100644 index 0000000000000000000000000000000000000000..f4f076b6c4e0c82cc297266bedc63034d5f5bf8b --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder.cc @@ -0,0 +1,374 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Range coder implementation, based on [1]. +// +// [1] G. N. N. Martin, "Range coding: an algorithm for removing redundancy from +// a digitised message", presented to the Video & Data Recording Conference, +// held in Southampton, July 24-27, 1979. +// +#include "tensorflow/contrib/coder/kernels/range_coder.h" + +#include +#include + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +RangeEncoder::RangeEncoder(int precision) : precision_(precision) { + CHECK_GT(precision, 0); + CHECK_LE(precision, 16); +} + +void RangeEncoder::Encode(int32 lower, int32 upper, string* sink) { + // Input requirement: 0 <= lower < upper <= 2^precision. + DCHECK_LE(0, lower); + DCHECK_LT(lower, upper); + DCHECK_LE(upper, 1 << precision_); + + // `base` and `size` represent a half-open interval [base, base + size). + // Loop invariant: 2^16 <= size <= 2^32. + // + // Note that keeping size above 2^16 is important. Since the interval sizes + // are quantized to up to 16 bits, the smallest interval size the encode may + // handle is 2^-16. If size is smaller than 2^16, a small interval input may + // collapse the encoder range into an empty interval. + const uint64 size = static_cast(size_minus1_) + 1; + DCHECK_NE(size >> 16, 0); + + // For short notation, let u := lower and v := upper. + // + // The input u, v represents a half-open interval [u, v) / 2^precision. + // This narrows the current interval roughly to + // [base + (size * u) / 2^precision, base + (size * v) / 2^precision). + // + // TODO(sjhwang): Try rounding if it helps improve compression ratio, at the + // expense of more operations. In the test using Zipf distribution, the + // overhead over the theoretical compression ratio was ~0.01%. + // NOTE: The max value of `size` is 2^32 and size > 0. Therefore `size * u` + // can be rewritten as `(size - 1) * u + u` and all the computation can be + // done in 32-bit mode. If 32-bit multiply is faster, then rewrite. + const uint32 a = (size * static_cast(lower)) >> precision_; + const uint32 b = ((size * static_cast(upper)) >> precision_) - 1; + DCHECK_LE(a, b); + + // Let's confirm the RHS of a, b fit in uint32 type. + // Recall that 0 <= u < 2^precision, and size <= 2^32. Therefore + // (size * u) / 2^precision < size <= 2^32, + // and the value of a fits in uint32 type. Similarly, since v <= 2^precision, + // (size * v) / 2^precision - 1 <= size - 1 < 2^32. + // For lower bound of b, note that 1 <= v, 2^16 <= size, and 16 <= precision. + // Therefore (size * v) / 2^precision - 1 >= 2^16 / 2^precision - 1 >= 0. + + // The new interval is [base + a, base + b] = [base + a, base + b + 1). + base_ += a; // May overflow. + size_minus1_ = b - a; + const bool base_overflow = (base_ < a); + + // The encoder has two states. Let's call them state 0 and state 1. + // State 0 is when base < base + size <= 2^32. + // State 1 is when base < 2^32 < base + size. + // + // The encoder initially starts in state 0, with base = 0, size = 2^32. + // + // TODO(sjhwang): Requires some profiling, but the encoder stays in state 0 + // most of the time. Should optimize code for state 0. + // + // Each Encode() has up to two places where the interval changes: + // #1. Refine the interval. [base, base + size) -> [base + a, base + b + 1). + // #2. Expand interval if the new size is too small, + // and each change may cause a state transition. + // + // First, consider when the current state is 0. + // + // In this case, the next state after #1 is always state 0, since refining + // interval only shrinks the interval, therefore new_base + new_size <= 2^32. + // + // Let us explain #2. + // + // Recall that at the beginning of each Encode(), the encoder requires + // 2^16 < size <= 2^32. As precision <= 16, the new interval size can be as + // small as 1, but never zero. + // + // To keep size above 2^16, if new size is smaller than or equal to 2^16, the + // encoder would left-shift base and size by 16 bits: size' <- size * 2^16. + // Note that new size' is now in the range [2^16, 2^32]. + // + // Since size is left-shifted, the same should be applied to base as well. + // However, after the left-shift, base will then contain 48 bits instead of 32 + // bits. Therefore prior to the shift, The upper 16 bits in base should be + // stored somewhere else. + // + // If the upper 16 bits of all values in the interval were the same, i.e., if + // base[32:16] == (base + size - 1)[32:16], then base[32:16] can be written + // out to `output` string, since any further Encode() only narrows down the + // interval and that 16 bits would never change. + // + // If the upper 16 bits were not all the same, since this happens only when + // size <= 2^16, the upper 16 bits may differ only by one, i.e., + // base[32:16] + 1 == (base + size - 1)[32:16]. At this stage, it is not + // determined yet whether base[32:16] should be written to the output or + // (base[32:16] + 1) should be written to the output. In this case, + // (base[32:16] + 1) is temporarily stored in `delay`, and base is + // left-shifted by 16 bits. + // + // In the latter case, the condition implies that (base // 2^16) and + // ((base + size - 1) // 2^16) were different. Therefore after left-shift by + // 16 bits, the new (base + size) is greater than 2^32, i.e., the encoder + // transition to state 1. + // + // ==== Summary ==== + // To detect the current encoder state, + // state 0: delay == 0 iff (base mod 2^32) < (base + size) mod 2^32, + // state 1: delay != 0 iff (base + size) mod 2^32 <= base mod 2^32, + // because size <= 2^32. + // + // ==== Summary for state 0 ==== + // 1. Interval refinement does not cause state transition. + // 2. Interval expansion may cause state transition, depending on the upper 16 + // bits of base and base + size - 1. + // + // Now suppose the previous state was 1. This means that + // base <= 2^32 < base + size. + // + // When in state 1, an interval refinement may trigger state transition. + // After Encode() refines the interval, there are three possibilities: + // #1. base <= 2^32 < base + size (unchanged), + // #2. 2^32 <= base < base + size (base overflowed), + // #3. base < base + size <= 2^32 (base + size - 1 underflowed). + // + // In case #1, the encoder remains in state 1. + // In case #2 or #3, the encoder state changes to state 0. + // + // ==== State transition for interval refinement ==== + // 1. state 0 -> state 0, + // 2. state 1 -> state 0 or state 1. + // + // Therefore if the new state is 1, then the previous state must have been + // state 1. + if (base_ + size_minus1_ < base_) { + // If statement checked if 2^32 < base + size. The new state is 1, hence the + // previous state was also state 1. + DCHECK_NE(((base_ - a) + size) >> 32, 0); + DCHECK_NE(delay_ & 0xFFFF, 0); + + // Like in state 0, if the new size is <= 2^16, then base and size should + // be left-shifted by 16 bits. Combine the conditions + // base <= 2^32 < base + size and size <= 2^16 to conclude that + // base[32:16] >= 0xFFFF and (base + size - 1)[32:16] = 0x0000. + // + // Note that 2^32 - base < size, and since base is at least 0xFFFF0000, + // 2^16 - base[16:0] < size. Let base' and size' be the new base and size + // after the bit-shift. Then 2^32 - base' < size' => 2^32 < base' + size'. + // Therefore the encoder remains in state 1. + // + // Lastly, `delay` is modified. Conceptually, delay has to be changed to + // delay' <- delay * 2^16 + (base + size - 1)[32:16]. + // Since we know above that (base + size - 1)[32:16] = 0x0000, there is no + // need to explicitly do the computation above, but rather store how many + // trailing zeros there were. For this reason, the lower 16 bits of + // `delay` stores the delayed value when state changed from 0 to 1, and + // delay[32:16] stores the # of trailing zeros (in bytes). + // + // ==== State transition for interval expansion ==== + // 1. state 0 -> state 0 or state 1, + // 2. state 1 -> state 1. + if (size_minus1_ >> 16 == 0) { + DCHECK_EQ(base_ >> 16, 0xFFFF); + base_ <<= 16; + size_minus1_ <<= 16; + size_minus1_ |= 0xFFFF; + // TODO(sjhwang): It is possible that for very long input, delay + // overflow during below. If overflow is detected, this delay is too + // long the encoder should forcefully move to state 0. In such case, + // base can be raised to 2^32 (force case #2), or (base + size) can be + // lowered to 2^32 (force case #3), depending on which transition + // keeps size larger. + CHECK_LT(delay_, static_cast(1) << 62); + delay_ += 0x20000; // Two more bytes of zeros. Check overflow? + } + return; + } + + // If reached here, the current state is 0. + // First handle the case when the previous state was state 1. + if (delay_ != 0) { + // In case #2 or #3, the encoder state changes to state 0. Recall that when + // the encoder state changed from state 0 to state 1, the top 16 bits of + // (base + size - 1) was temporarily stored in `delay`, because the output + // could be either (delay - 1) or (delay). + // + // And from above, the delayed value encoded in `delay` is + // delay' <- delay[16:0] * 2^(8 * delay[MAX:16]) + // + // In case #2, the interval moved below 2^32. So (delay' - 1) is the + // converged value after interval refinements. Write out + // (delay[16:0] - 1) and write (8 * delay[MAX:16]) bytes of 0xFF. + // + // In case #3, the interval moved above 2^32. So delay' is the converged + // value after interval refinement. Write out delay[16:0] and write + // (8 * delay[MAX:16]) bytes of 0x00. + if (base_overflow) { + // Case #2. + DCHECK_NE((static_cast(base_ - a) + a) >> 32, 0); + sink->push_back(static_cast(delay_ >> 8)); + sink->push_back(static_cast(delay_ >> 0)); + sink->append(delay_ >> 16, static_cast(0)); + } else { + // Case #3. + DCHECK_EQ(static_cast(base_ + size_minus1_) >> 32, 0); + --delay_; + sink->push_back(static_cast(delay_ >> 8)); + sink->push_back(static_cast(delay_ >> 0)); + sink->append(delay_ >> 16, static_cast(0xFF)); + } + // Reset to state 0. + delay_ = 0; + } + + if (size_minus1_ >> 16 == 0) { + const uint32 top = base_ >> 16; + + base_ <<= 16; + size_minus1_ <<= 16; + size_minus1_ |= 0xFFFF; + + if (base_ <= base_ + size_minus1_) { + // Still in state 0. Write the top 16 bits. + sink->push_back(static_cast(top >> 8)); + sink->push_back(static_cast(top)); + } else { + // New state is 1. + DCHECK_LT(top, 0xFFFF); + delay_ = top + 1; + } + } +} + +void RangeEncoder::Finalize(string* sink) { + // Finalize the encode by writing out any number in the interval + // [base, base + size). + // + // Trailing zeros are not explicitly written out as decoder can fill in zeros + // by default. + if (delay_ != 0) { + // The last state was state 1. Since base < 2^32 < base + size, pick 2^32 + // (state 1, case #3). + // NOTE: It is a bit difficult to trigger this code path on purpose. + // TODO(sjhwang): Find a way to trigger this code path for test coverage. + sink->push_back(static_cast(delay_ >> 8)); + if ((delay_ & 0xFF) != 0) { + sink->push_back(static_cast(delay_)); + } + } else if (base_ != 0) { + // If base == 0, then pick 0 from [base, base + size) and no zeros are + // explcitly written. + // + // Otherwise, pick (base + (2^16 - base[16:0])), i.e., round up base to the + // next multiple of 2^16. As 2^16 < size, this value should be in the + // interval [base, base + size). + const uint32 mid = ((base_ - 1) >> 16) + 1; + DCHECK_EQ(mid & 0xFFFF, mid); + sink->push_back(static_cast(mid >> 8)); + if ((mid & 0xFF) != 0) { + sink->push_back(static_cast(mid >> 0)); + } + } + + base_ = 0; + size_minus1_ = std::numeric_limits::max(); + delay_ = 0; +} + +RangeDecoder::RangeDecoder(const string& source, int precision) + : current_(source.begin()), + begin_(source.begin()), + end_(source.end()), + precision_(precision) { + CHECK_LE(precision, 16); + + Read16BitValue(); + Read16BitValue(); +} + +int32 RangeDecoder::Decode(tensorflow::gtl::ArraySlice cdf) { + const uint64 size = static_cast(size_minus1_) + 1; + const uint64 offset = + ((static_cast(value_ - base_) + 1) << precision_) - 1; + + // This is similar to std::lower_range() with std::less_equal as comparison. + // After the binary search, `pv` points to the smallest number v that + // satisfies offset < (size * v) / 2^precision. + + // Assumes that cdf[0] == 0. Therefore (size * cdf[0]) / 2^precision is always + // less than or equal to offset. + const int32* pv = cdf.data() + 1; + // `len` can be cdf.size() - 2 if there is guarantee that the last element of + // cdf is 2^precision. + auto len = cdf.size() - 1; + DCHECK_GT(len, 0); + + do { + const auto half = len / 2; + const int32* mid = pv + half; + DCHECK_GE(*mid, 0); + DCHECK_LE(*mid, 1 << precision_); + if (size * static_cast(*mid) <= offset) { + pv = mid + 1; + len -= half + 1; + } else { + len = half; + } + } while (len > 0); + + // If (size * v) / 2^precision <= offset for all v in cdf, then pv points to + // one after the last element of cdf. That is a decoding error. + // + // TODO(sjhwang): Consider returning -1 to indicate error. Or start len = + // cdf.size() - 2 instead and give up detecting this error. + CHECK_LT(pv, cdf.data() + cdf.size()); + + const uint32 a = (size * static_cast(*(pv - 1))) >> precision_; + const uint32 b = ((size * static_cast(*pv)) >> precision_) - 1; + DCHECK_LE(a, offset >> precision_); + DCHECK_LE(offset >> precision_, b); + + base_ += a; + size_minus1_ = b - a; + + if (size_minus1_ >> 16 == 0) { + base_ <<= 16; + size_minus1_ <<= 16; + size_minus1_ |= 0xFFFF; + + Read16BitValue(); + } + + return pv - cdf.data() - 1; +} + +void RangeDecoder::Read16BitValue() { + value_ <<= 8; + if (current_ != end_) { + value_ |= static_cast(*current_++); + } + value_ <<= 8; + if (current_ != end_) { + value_ |= static_cast(*current_++); + } +} +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder.h b/tensorflow/contrib/coder/kernels/range_coder.h new file mode 100644 index 0000000000000000000000000000000000000000..c24fb707fc9f1776a4e6e7be7df3245c0cdccb0b --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder.h @@ -0,0 +1,109 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ + +#include +#include + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +class RangeEncoder { + public: + // `precision` determines the granularity of probability masses passed to + // Encode() function below. + // + // REQUIRES: 0 < precision <= 16. + explicit RangeEncoder(int precision); + + // Encodes a half-open interval [lower / 2^precision, upper / 2^precision). + // Suppose each character to be encoded is from an integer-valued + // distribution. When encoding a random character x0, the arguments lower and + // upper represent + // Pr(X < x0) = lower / 2^precision, + // Pr(X < x0 + 1) = upper / 2^precision, + // where X is a random variable following the distribution. + // + // For example, assume that the distribution has possible outputs 0, 1, 2, ... + // To encode value 0, lower = 0 and upper = Pr(X = 0). + // To encode value 1, lower = Pr(X = 0) and upper = Pr(X = 0 or 1). + // To encode value 2, lower = Pr(X = 0 or 1) and upper = Pr(X = 0, 1, or 2). + // ... + // + // REQUIRES: 0 <= lower < upper <= 2^precision. + void Encode(int32 lower, int32 upper, string* sink); + + // The encode may contain some under-determined values from previous encoding. + // After Encode() calls, Finalize() must be called. Otherwise the encoded + // string may not be decoded. + void Finalize(string* sink); + + private: + uint32 base_ = 0; + uint32 size_minus1_ = std::numeric_limits::max(); + uint64 delay_ = 0; + + const int precision_; +}; + +class RangeDecoder { + public: + // Holds a reference to `source`. The caller has to make sure that `source` + // outlives the decoder object. + // + // REQUIRES: `precision` must be the same as the encoder's precision. + // REQUIRES: 0 < precision <= 16. + RangeDecoder(const string& source, int precision); + + // Decodes a character from `source` using CDF. The size of `cdf` should be + // one more than the number of the character in the alphabet. + // + // If x0, x1, x2, ... are the possible characters (in increasing order) from + // the distribution, then + // cdf[0] = 0 + // cdf[1] = Pr(X <= x0), + // cdf[2] = Pr(X <= x1), + // cdf[3] = Pr(X <= x2), + // ... + // + // The returned value is an index to `cdf` where the decoded character + // corresponds to. + // + // REQUIRES: cdf.size() > 1. + // REQUIRES: cdf[i] <= cdf[i + 1] for i = 0, 1, ..., cdf.size() - 2. + // REQUIRES: cdf[cdf.size() - 1] <= 2^precision. + // + // In practice the last element of `cdf` should equal to 2^precision. + int32 Decode(gtl::ArraySlice cdf); + + private: + void Read16BitValue(); + + uint32 base_ = 0; + uint32 size_minus1_ = std::numeric_limits::max(); + uint32 value_ = 0; + + string::const_iterator current_; + const string::const_iterator begin_; + const string::const_iterator end_; + + const int precision_; +}; +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops.cc b/tensorflow/contrib/coder/kernels/range_coder_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..cde7982530fea6407aaf074f7af4a22263d50da3 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_ops.cc @@ -0,0 +1,307 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/coder/kernels/range_coder.h" +#include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { +// A helper class to iterate over data and cdf simultaneously, while cdf is +// broadcasted to data. +// NOTE: Moving this class out of anonymous namespace impacts compiler +// optimization and affects performance. When moving this code around (e.g., +// into a library header), be sure to check the benchmark tests. +template +class BroadcastRange { + public: + BroadcastRange(T* data_pointer, gtl::ArraySlice data_shape, + const U* cdf_pointer, gtl::ArraySlice cdf_shape) + : data_pointer_(data_pointer), cdf_pointer_(cdf_pointer) { + CHECK(!data_shape.empty()); + CHECK_EQ(data_shape.size(), N); + CHECK_EQ(cdf_shape.size(), N + 1); + + std::copy(data_shape.begin(), data_shape.end(), &data_shape_[0]); + data_index_.fill(0); + + const int64 innermost_stride = cdf_shape[N]; + cdf_displace_.fill(innermost_stride); + + // Pre-compute the pointer displacement for cdf. + int64 stride = innermost_stride; + for (int i = N - 1; i >= 0; --i) { + const bool broadcasting = (cdf_shape[i] <= 1); + + // When the data linear index advances by one, the cdf linear index + // advances by `innermost_stride`. + // + // Suppose that the i-th axis coordinate of data increased by one, and + // that i-th axis is broadcasting. The cdf linear index should be wound + // back by i-th axis stride, so that i-th axis coordinate of cdf is + // effectively kept at 0. + if (broadcasting) { + cdf_displace_[i] -= stride; + } + stride *= cdf_shape[i]; + } + } + + // Returns the pointers to the current iterating locations to data and cdf + // tensors. + // + // Note that this function does not track whether data pointer is running past + // the end of data buffer. The caller has to make sure Next() is called no + // more than that. + std::pair Next() { + std::pair return_value = {data_pointer_, cdf_pointer_}; + + int i = N - 1; + for (; i > 0; --i) { + ++data_index_[i]; + if (data_index_[i] < data_shape_[i]) { + break; + } + data_index_[i] = 0; + } + + // Advance data pointer by one. + data_pointer_ += 1; + + // For cdf pointer, it's more complicated because of broadcasting. When i-th + // coordinate increase by one, and if i-th axis is broadcasting, then we + // need to rewind back the pointer so that the effective i-th axis + // coordinate for cdf is always 0. This value is precomputed as + // cdf_displace_. + cdf_pointer_ += cdf_displace_[i]; + return return_value; + } + + private: + std::array data_shape_; + std::array cdf_displace_; + std::array data_index_; + + T* data_pointer_; + const U* cdf_pointer_; +}; + +Status CheckCdfShape(const TensorShape& data_shape, + const TensorShape& cdf_shape) { + if (TF_PREDICT_FALSE(cdf_shape.dims() != data_shape.dims() + 1)) { + return errors::InvalidArgument( + "`cdf` should have one more axis than `data`: data shape=", + data_shape.DebugString(), ", cdf shape=", cdf_shape.DebugString()); + } + + if (TF_PREDICT_FALSE(cdf_shape.dim_size(cdf_shape.dims() - 1) <= 1)) { + return errors::InvalidArgument( + "The last dimension of `cdf` should be > 1: ", cdf_shape.DebugString()); + } + + return Status::OK(); +} + +// Non-incremental encoder op ------------------------------------------------- +class RangeEncodeOp : public OpKernel { + public: + explicit RangeEncodeOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_)); + OP_REQUIRES(context, 0 < precision_ && precision_ <= 16, + errors::InvalidArgument("`precision` must be in [1, 16]: ", + precision_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& data = context->input(0); + const Tensor& cdf = context->input(1); + + OP_REQUIRES_OK(context, CheckCdfShape(data.shape(), cdf.shape())); + + std::vector data_shape, cdf_shape; + OP_REQUIRES_OK( + context, MergeAxes(data.shape(), cdf.shape(), &data_shape, &cdf_shape)); + + Tensor* output_tensor; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape{}, &output_tensor)); + string* output = &output_tensor->scalar()(); + + switch (data_shape.size()) { +#define RANGE_ENCODE_CASE(dims) \ + case dims: { \ + RangeEncodeImpl(data.flat(), data_shape, \ + cdf.flat_inner_dims(), cdf_shape, output); \ + } break + RANGE_ENCODE_CASE(1); + RANGE_ENCODE_CASE(2); + RANGE_ENCODE_CASE(3); + RANGE_ENCODE_CASE(4); + RANGE_ENCODE_CASE(5); + RANGE_ENCODE_CASE(6); +#undef RANGE_ENCODE_CASE + default: + context->CtxFailure(errors::InvalidArgument( + "Irregular broadcast pattern: ", data.shape().DebugString(), ", ", + cdf.shape().DebugString())); + return; + } + } + + private: + template + void RangeEncodeImpl(TTypes::ConstFlat data, + gtl::ArraySlice data_shape, + TTypes::ConstMatrix cdf, + gtl::ArraySlice cdf_shape, string* output) const { + const int64 data_size = data.size(); + const int64 cdf_size = cdf.size(); + const int64 chip_size = cdf.dimension(1); + + BroadcastRange view{data.data(), data_shape, + cdf.data(), cdf_shape}; + RangeEncoder encoder{precision_}; + for (int64 linear = 0; linear < data_size; ++linear) { + const auto pair = view.Next(); + + const int64 index = *pair.first; + DCHECK_GE(index, 0); + DCHECK_LT(index + 1, chip_size); + + const int32* cdf_slice = pair.second; + DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size); + + const int32 lower = cdf_slice[index]; + const int32 upper = cdf_slice[index + 1]; + encoder.Encode(lower, upper, output); + } + + encoder.Finalize(output); + } + + int precision_; +}; + +REGISTER_KERNEL_BUILDER(Name("RangeEncode").Device(DEVICE_CPU), RangeEncodeOp); + +// Non-incremental decoder op ------------------------------------------------- +class RangeDecodeOp : public OpKernel { + public: + explicit RangeDecodeOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_)); + OP_REQUIRES(context, 0 < precision_ && precision_ <= 16, + errors::InvalidArgument("`precision` must be in [1, 16]: ", + precision_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& encoded_tensor = context->input(0); + const Tensor& shape = context->input(1); + const Tensor& cdf = context->input(2); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(encoded_tensor.shape()), + errors::InvalidArgument("Invalid `encoded` shape: ", + encoded_tensor.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()), + errors::InvalidArgument("Invalid `shape` shape: ", + shape.shape().DebugString())); + TensorShape output_shape; + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(shape.vec(), + &output_shape)); + OP_REQUIRES_OK(context, CheckCdfShape(output_shape, cdf.shape())); + + std::vector data_shape, cdf_shape; + OP_REQUIRES_OK( + context, MergeAxes(output_shape, cdf.shape(), &data_shape, &cdf_shape)); + + const string& encoded = encoded_tensor.scalar()(); + + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + switch (data_shape.size()) { +#define RANGE_DECODE_CASE(dim) \ + case dim: { \ + RangeDecodeImpl(output->flat(), data_shape, \ + cdf.flat_inner_dims(), cdf_shape, encoded); \ + } break + RANGE_DECODE_CASE(1); + RANGE_DECODE_CASE(2); + RANGE_DECODE_CASE(3); + RANGE_DECODE_CASE(4); + RANGE_DECODE_CASE(5); + RANGE_DECODE_CASE(6); +#undef RANGE_DECODE_CASE + default: + context->CtxFailure(errors::InvalidArgument( + "Irregular broadcast pattern: ", output_shape.DebugString(), ", ", + cdf.shape().DebugString())); + return; + } + } + + private: + template + void RangeDecodeImpl(TTypes::Flat output, + gtl::ArraySlice output_shape, + TTypes::ConstMatrix cdf, + gtl::ArraySlice cdf_shape, + const string& encoded) const { + BroadcastRange view{output.data(), output_shape, + cdf.data(), cdf_shape}; + + RangeDecoder decoder{encoded, precision_}; + + const int64 output_size = output.size(); + const int64 cdf_size = cdf.size(); + const auto chip_size = + static_cast::size_type>(cdf.dimension(1)); + + for (int64 i = 0; i < output_size; ++i) { + const auto pair = view.Next(); + + int16* data = pair.first; + DCHECK_LT(data, output.data() + output_size); + + const int32* cdf_slice = pair.second; + DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size); + + *data = decoder.Decode(gtl::ArraySlice{cdf_slice, chip_size}); + } + } + + int precision_; +}; + +REGISTER_KERNEL_BUILDER(Name("RangeDecode").Device(DEVICE_CPU), RangeDecodeOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae4d9d2836a0f89a9765004a85bc3c292b0e484f --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc @@ -0,0 +1,521 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/contrib/coder/kernels/range_coder.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { +int LogUniform(random::SimplePhilox* gen, uint32 n) { + CHECK_GT(n, 0); + + // Split [0, n) into {0}, [1, 2), [2, 4), [4, 8), ..., [2^(m-1), n). + const int m = Log2Ceiling(n); + + int outcome; + do { + // Uniform() consumes at least 32 bits per call, therefore this is somewhat + // wasteful implementation. Since this is used only for test, we do not + // refine this implementation further. + const int k = gen->Uniform(m + 1) - 1; + // If k == -1, then sample from {0}. + // If k == 0, then sample from [1, 2). + // If k == 1, then sample from [2, 4), ... and so on. + if (k < 1) { + outcome = k + 1; + } else { + outcome = (1 << k) + gen->Uniform(1 << k); + } + } while (n <= outcome); + return outcome; +} + +std::vector ComputeStrides(const TensorShape& shape) { + std::vector stride(shape.dims()); + int64 current = 1; + for (int i = shape.dims() - 1; i >= 0; --i) { + stride[i] = current; + current *= shape.dim_size(i); + } + return stride; +} + +class RangeCoderOpsTest : public OpsTestBase { + protected: + Status RunEncodeOp(int precision, gtl::ArraySlice input, + Tensor* output) { + TF_RETURN_IF_ERROR(NodeDefBuilder("encode", "RangeEncode") + .Input(tensorflow::FakeInput(DT_INT16)) + .Input(tensorflow::FakeInput(DT_INT32)) + .Attr("precision", precision) + .Finalize(node_def())); + TF_RETURN_IF_ERROR(InitOp()); + + inputs_.clear(); + std::vector copies(input.size()); + for (int i = 0; i < input.size(); ++i) { + copies[i] = input[i]; + inputs_.emplace_back(&copies[i]); + } + + TF_RETURN_IF_ERROR(RunOpKernel()); + + *output = *GetOutput(0); + inputs_.clear(); + + return Status::OK(); + } + + Status RunDecodeOp(int precision, gtl::ArraySlice input, + Tensor* output) { + TF_RETURN_IF_ERROR(NodeDefBuilder("decode", "RangeDecode") + .Input(tensorflow::FakeInput(DT_STRING)) + .Input(tensorflow::FakeInput(DT_INT32)) + .Input(tensorflow::FakeInput(DT_INT32)) + .Attr("precision", precision) + .Finalize(node_def())); + TF_RETURN_IF_ERROR(InitOp()); + + inputs_.clear(); + std::vector copies(input.size()); + for (int i = 0; i < input.size(); ++i) { + copies[i] = input[i]; + inputs_.emplace_back(&copies[i]); + } + + TF_RETURN_IF_ERROR(RunOpKernel()); + + *output = *GetOutput(0); + inputs_.clear(); + + return Status::OK(); + } + + void TestEncodeAndDecode(int precision, const Tensor& data, + const Tensor& cdf) { + Tensor encoded; + TF_ASSERT_OK(RunEncodeOp(precision, {data, cdf}, &encoded)); + + const TensorShape& data_shape = data.shape(); + Tensor shape{DT_INT32, {data_shape.dims()}}; + for (int i = 0; i < data_shape.dims(); ++i) { + shape.flat()(i) = data_shape.dim_size(i); + } + + Tensor decoded; + TF_ASSERT_OK(RunDecodeOp(precision, {encoded, shape, cdf}, &decoded)); + + EXPECT_EQ(decoded.dtype(), data.dtype()); + EXPECT_EQ(decoded.shape(), data.shape()); + EXPECT_EQ(decoded.tensor_data(), data.tensor_data()); + } + + void PopulateMaxValues(random::SimplePhilox* gen, Tensor* maxvalue_tensor, + int min_maxvalue, int max_maxvalue) { + const int range = max_maxvalue - min_maxvalue; + TTypes::Flat flat = maxvalue_tensor->flat(); + + for (int64 i = 0; i < flat.size(); ++i) { + flat(i) = min_maxvalue + gen->Uniform(range); + } + } + + void BuildCdf(random::SimplePhilox* gen, Tensor* data_tensor, + Tensor* cdf_tensor, const Tensor& maxvalue_tensor) { + CHECK(TensorShapeUtils::StartsWith(cdf_tensor->shape(), + maxvalue_tensor.shape())); + CHECK_EQ(cdf_tensor->dims(), maxvalue_tensor.dims() + 1); + const int64 chip_size = cdf_tensor->dim_size(cdf_tensor->dims() - 1); + + std::vector data_stride = ComputeStrides(data_tensor->shape()); + std::vector cdf_stride = ComputeStrides(cdf_tensor->shape()); + + for (int i = 0; i < cdf_tensor->dims(); ++i) { + if (cdf_tensor->dim_size(i) == 1) { + cdf_stride[i] = 0; + } + } + + Tensor histogram_tensor{DT_INT32, cdf_tensor->shape()}; + TTypes::Flat data = data_tensor->flat(); + TTypes::Flat histogram = histogram_tensor.flat(); + TTypes::ConstFlat maxvalue = maxvalue_tensor.flat(); + histogram.setZero(); + + for (int64 index = 0; index < data.size(); ++index) { + int64 temp = index; + int64 offset = 0; + for (int dim = 0; dim < data_stride.size(); ++dim) { + const int64 coord = temp / data_stride[dim]; + offset += coord * cdf_stride[dim]; + temp -= coord * data_stride[dim]; + } + ASSERT_EQ(temp, 0); + + const int64 maxvalue_offset = offset / chip_size; + CHECK_EQ(maxvalue_offset * chip_size, offset); + CHECK_LT(maxvalue(maxvalue_offset) + 1, chip_size); + const int value = LogUniform(gen, maxvalue(maxvalue_offset)); + data(index) = value; + histogram(offset + value + 1) += 1; + } + + cdf_tensor->flat_inner_dims() = + histogram_tensor.flat_inner_dims().cumsum(1); + } +}; + +TEST_F(RangeCoderOpsTest, NoBroadcast) { + constexpr int kPrecision = 14; + constexpr int kMaxValue = 10; + + Tensor data{DT_INT16, {1, 32, 32, 16}}; + Tensor temp{DT_INT32, {1, 1, 1, 1, kMaxValue + 2}}; + Tensor maxvalue{DT_INT16, {1, 1, 1, 1}}; + maxvalue.flat()(0) = kMaxValue; + + ASSERT_LE(data.shape().num_elements(), 1 << kPrecision); + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + BuildCdf(&gen, &data, &temp, maxvalue); + + const Eigen::array broadcast = {1, 32, 32, 16, 1}; + + Tensor cdf{DT_INT32, {1, 32, 32, 16, kMaxValue + 2}}; + cdf.tensor() = temp.tensor().broadcast(broadcast); + + TestEncodeAndDecode(kPrecision, data, cdf); +} + +TEST_F(RangeCoderOpsTest, Broadcast1Axis) { + constexpr int kPrecision = 9; + constexpr int kDimensionSize = 1 << kPrecision; + constexpr int kMinMaxValue = 10; + constexpr int kMaxMaxValue = 64; + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + Tensor data{DT_INT16, {1, kDimensionSize, kDimensionSize}}; + + Tensor maxvalue{DT_INT16, {kDimensionSize}}; + PopulateMaxValues(&gen, &maxvalue, kMinMaxValue, kMaxMaxValue); + + { + // Axis 1. + Tensor maxvalue1; + ASSERT_TRUE(maxvalue1.CopyFrom(maxvalue, {1, 1, kDimensionSize})); + + Tensor cdf{DT_INT32, {1, 1, kDimensionSize, kMaxMaxValue + 2}}; + BuildCdf(&gen, &data, &cdf, maxvalue1); + TestEncodeAndDecode(kPrecision, data, cdf); + } + + { + // Axis 2. + Tensor maxvalue2; + ASSERT_TRUE(maxvalue2.CopyFrom(maxvalue, {1, kDimensionSize, 1})); + + Tensor cdf{DT_INT32, {1, kDimensionSize, 1, kMaxMaxValue + 2}}; + BuildCdf(&gen, &data, &cdf, maxvalue2); + TestEncodeAndDecode(kPrecision, data, cdf); + } +} + +TEST_F(RangeCoderOpsTest, Broadcast2Axes) { + constexpr int kPrecision = 13; + constexpr int kDimensionSize1 = 1 << (kPrecision / 2); + constexpr int kDimensionSize2 = 1 << (kPrecision - kPrecision / 2); + constexpr int kMinMaxValue = 10; + constexpr int kMaxMaxValue = 64; + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + Tensor maxvalue{DT_INT16, {2, 1, 1, 7}}; + PopulateMaxValues(&gen, &maxvalue, kMinMaxValue, kMaxMaxValue); + + Tensor data{DT_INT16, {2, kDimensionSize1, kDimensionSize2, 7}}; + Tensor cdf{DT_INT32, {2, 1, 1, 7, kMaxMaxValue + 2}}; + BuildCdf(&gen, &data, &cdf, maxvalue); + TestEncodeAndDecode(kPrecision, data, cdf); +} + +TEST_F(RangeCoderOpsTest, InvalidCdfShape) { + Tensor data{DT_INT16, {3, 3}}; + Tensor cdf{DT_INT32, {3, 3}}; + + Tensor unused; + { + const Status status = RunEncodeOp(10, {data, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("`cdf` should have one more axis"), + string::npos); + } + + Tensor empty{DT_STRING, {}}; + Tensor shape{DT_INT32, {2}}; + shape.vec().setValues({3, 3}); + { + const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("`cdf` should have one more axis"), + string::npos); + } + + cdf = Tensor{DT_INT32, {3, 3, 1}}; + { + const Status status = RunEncodeOp(10, {data, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE( + status.error_message().find("last dimension of `cdf` should be > 1"), + string::npos); + } + { + const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE( + status.error_message().find("last dimension of `cdf` should be > 1"), + string::npos); + } +} + +TEST_F(RangeCoderOpsTest, DecoderShapeFn) { + Tensor encoded_tensor{DT_STRING, {}}; + Tensor shape_tensor{DT_INT32, {3}}; + Tensor cdf_tensor{DT_INT32, {4, 6, 8, 2}}; + + shape_tensor.flat().setValues({4, 6, 8}); + + Graph g{OpRegistry::Global()}; + Node* encoded = test::graph::Constant(&g, encoded_tensor); + Node* shape = test::graph::Constant(&g, shape_tensor); + Node* cdf = test::graph::Constant(&g, cdf_tensor); + Node* decode; + TF_ASSERT_OK(NodeBuilder("range_decode", "RangeDecode", g.op_registry()) + .Input(encoded) + .Input(shape) + .Input(cdf) + .Attr("precision", 10) + .Finalize(&g, &decode)); + + ShapeRefiner refiner{g.versions().producer(), g.op_registry()}; + TF_ASSERT_OK(refiner.AddNode(encoded)); + TF_ASSERT_OK(refiner.AddNode(shape)); + TF_ASSERT_OK(refiner.AddNode(cdf)); + TF_ASSERT_OK(refiner.AddNode(decode)); + + auto* context = refiner.GetContext(decode); + ASSERT_NE(context, nullptr); + + ASSERT_EQ(context->num_outputs(), 1); + auto shape_handle = context->output(0); + + ASSERT_EQ(context->Rank(shape_handle), 3); + EXPECT_EQ(context->Value(context->Dim(shape_handle, 0)), 4); + EXPECT_EQ(context->Value(context->Dim(shape_handle, 1)), 6); + EXPECT_EQ(context->Value(context->Dim(shape_handle, 2)), 8); +} + +TEST_F(RangeCoderOpsTest, InvalidBroadcast) { + Tensor data{DT_INT16, {3, 3}}; + Tensor cdf{DT_INT32, {3, 2, 2}}; + + Tensor unused; + { + const Status status = RunEncodeOp(10, {data, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("Cannot broadcast shape"), + string::npos); + } + + data = Tensor{DT_INT16, {3, 1}}; + cdf = Tensor{DT_INT32, {3, 3, 2}}; + Tensor empty{DT_STRING, {}}; + Tensor shape{DT_INT32, {2}}; + shape.vec().setValues({3, 1}); + { + const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("Cannot broadcast shape"), + string::npos); + } + + std::vector shape_vector = {2, 2, 2, 2, 2, 2, 2, 2, 2}; + data = Tensor{DT_INT16, TensorShape{shape_vector}}; + cdf = Tensor{DT_INT32, {2, 1, 2, 1, 2, 1, 2, 1, 2, 2}}; + { + const Status status = RunEncodeOp(10, {data, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("Irregular broadcast"), string::npos); + } + + shape = Tensor{DT_INT32, {static_cast(shape_vector.size())}}; + for (int i = 0; i < shape_vector.size(); ++i) { + shape.flat()(i) = shape_vector[i]; + } + { + const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("Irregular broadcast"), string::npos); + } +} + +// Benchmark ------------------------------------------------------------- + +// This function creates RangeEncode graph with CDF built from a separate data +// sample. +Graph* CreateRangeEncodeFullBroadcastGraph(const TensorShape& shape, + int precision) { + CHECK_EQ(shape.dims(), 4); + + constexpr int kAlphabetSize = 70; + + Tensor histogram{DT_INT32, {kAlphabetSize + 1}}; + TTypes::Vec h = histogram.vec(); + h.setConstant(1); + h(0) = 0; + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + for (int i = 0; i < (1 << precision) - kAlphabetSize; ++i) { + const int value = LogUniform(&gen, kAlphabetSize - 1); + h(value + 1) += 1; + } + + Tensor cdf{DT_INT32, {1, 1, 1, 1, kAlphabetSize + 1}}; + cdf.flat() = h.cumsum(0); + + Tensor data{DT_INT16, shape}; + TTypes::Flat d = data.flat(); + for (int64 i = 0; i < d.size(); ++i) { + d(i) = LogUniform(&gen, kAlphabetSize - 1); + } + + Graph* g = new Graph(OpRegistry::Global()); + TF_CHECK_OK(NodeBuilder("range_encode", "RangeEncode", g->op_registry()) + .Input(test::graph::Constant(g, data)) + .Input(test::graph::Constant(g, cdf)) + .Attr("precision", precision) + .Finalize(g, nullptr)); + return g; +} + +// This function creates RangeDecode graph with CDF built from a separate data +// sample. +Graph* CreateRangeDecodeFullBroadcastGraph(const TensorShape& shape, + int precision) { + CHECK_EQ(shape.dims(), 4); + + constexpr int kAlphabetSize = 200; + const int64 num_elements = shape.num_elements(); + + Tensor histogram{DT_INT32, {kAlphabetSize + 1}}; + TTypes::Vec h = histogram.vec(); + h.setConstant(1); + h(0) = 0; + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + for (int i = 0; i < (1 << precision) - kAlphabetSize; ++i) { + const int value = LogUniform(&gen, kAlphabetSize - 1); + h(value + 1) += 1; + } + + Tensor cdf_tensor{DT_INT32, {1, 1, 1, 1, kAlphabetSize + 1}}; + TTypes::Flat cdf = cdf_tensor.flat(); + cdf = h.cumsum(0); + + Tensor string_tensor{DT_STRING, TensorShape{}}; + string& sink = string_tensor.scalar()(); + + RangeEncoder encoder{precision}; + for (int64 i = 0; i < num_elements; ++i) { + const int value = LogUniform(&gen, kAlphabetSize - 1); + encoder.Encode(cdf(value), cdf(value + 1), &sink); + } + encoder.Finalize(&sink); + + Tensor shape_tensor{DT_INT32, {shape.dims()}}; + for (int i = 0; i < shape.dims(); ++i) { + shape_tensor.flat()(i) = shape.dim_size(i); + } + + Graph* g = new Graph(OpRegistry::Global()); + TF_CHECK_OK(NodeBuilder("range_decode", "RangeDecode", g->op_registry()) + .Input(test::graph::Constant(g, string_tensor)) + .Input(test::graph::Constant(g, shape_tensor)) + .Input(test::graph::Constant(g, cdf_tensor)) + .Attr("precision", precision) + .Finalize(g, nullptr)); + return g; +} + +void RunTensorFlowBenchmark(int iters, Graph* g, int64 num_elements) { + SessionOptions opts; + opts.config.set_intra_op_parallelism_threads(1); + opts.config.set_inter_op_parallelism_threads(1); + + testing::UseRealTime(); + test::Benchmark("cpu", g, &opts).Run(iters); + + const int64 num_items = static_cast(iters) * num_elements; + testing::ItemsProcessed(num_items); +} + +void BM_RangeEncodeFullBroadcast(int iters, int code_size) { + constexpr int kPrecision = 14; + const TensorShape shape = {1, code_size, code_size, 256}; + Graph* g = CreateRangeEncodeFullBroadcastGraph(shape, kPrecision); + RunTensorFlowBenchmark(iters, g, shape.num_elements()); +} + +BENCHMARK(BM_RangeEncodeFullBroadcast)->Arg(32)->Arg(64); + +void BM_RangeDecodeFullBroadcast(int iters, int code_size) { + constexpr int kPrecision = 14; + const TensorShape shape = {1, code_size, code_size, 256}; + Graph* g = CreateRangeDecodeFullBroadcastGraph(shape, kPrecision); + RunTensorFlowBenchmark(iters, g, shape.num_elements()); +} + +BENCHMARK(BM_RangeDecodeFullBroadcast)->Arg(32)->Arg(64); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..d66730cb4881ea92b5477047c500291fa9c0c290 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h" + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::errors::InvalidArgument; + +namespace tensorflow { +Status MergeAxes(const TensorShape& broadcast_shape, + const TensorShape& storage_shape, + std::vector* merged_broadcast_shape_pointer, + std::vector* merged_storage_shape_pointer) { + CHECK_EQ(storage_shape.dims(), broadcast_shape.dims() + 1); + + std::vector& merged_broadcast_shape = *merged_broadcast_shape_pointer; + std::vector& merged_storage_shape = *merged_storage_shape_pointer; + + // The shapes are simplified so that the conversions between linear index + // and coordinates takes less CPU cycles. Two adjacent dimensions are + // merged if they both are broadcasting dimensions or if they both are + // non-broadcasting dimensions. + merged_broadcast_shape.resize(1); + merged_broadcast_shape[0] = 1; + merged_storage_shape.resize(1); + merged_storage_shape[0] = 1; + + for (int i = 0, j = 0; j < broadcast_shape.dims(); ++j) { + if (TF_PREDICT_FALSE( + (broadcast_shape.dim_size(j) != storage_shape.dim_size(j)) && + (storage_shape.dim_size(j) != 1))) { + return InvalidArgument("Cannot broadcast shape ", + storage_shape.DebugString(), " to ", + broadcast_shape.DebugString()); + } + + const bool was_broadcasting = (merged_storage_shape[i] == 1); + const bool is_broadcasting = (storage_shape.dim_size(j) == 1); + + // Merge two adjacent axes if they both are broadcasting or both are + // non-broadcasting axes. The second and the third conditions in the if + // clause below are when the previously merged axis or the next j-th axis + // may be interpreted as either a broadcasting or a non-broadcasting axis. + const bool merge = (was_broadcasting == is_broadcasting) || + (broadcast_shape.dim_size(j) <= 1) || + (merged_broadcast_shape[i] <= 1); + + if (merge) { + merged_broadcast_shape[i] *= broadcast_shape.dim_size(j); + merged_storage_shape[i] *= storage_shape.dim_size(j); + } else { + // Move to the next axis. + merged_broadcast_shape.push_back(broadcast_shape.dim_size(j)); + merged_storage_shape.push_back(storage_shape.dim_size(j)); + ++i; + } + } + + int64 storage_stride = 1; + for (int i = broadcast_shape.dims(); i < storage_shape.dims(); ++i) { + storage_stride *= storage_shape.dim_size(i); + } + merged_storage_shape.push_back(storage_stride); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_util.h b/tensorflow/contrib/coder/kernels/range_coder_ops_util.h new file mode 100644 index 0000000000000000000000000000000000000000..95241a8682891dc94780a9194d20aa9dc22e17c8 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_util.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +// The shapes are simplified to reduce indexing cost. +Status MergeAxes(const TensorShape& broadcast_shape, + const TensorShape& storage_shape, + std::vector* merged_broadcast_shape_pointer, + std::vector* merged_storage_shape_pointer); +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ diff --git a/tensorflow/contrib/coder/kernels/range_coder_test.cc b/tensorflow/contrib/coder/kernels/range_coder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..442994bf7c7566c1cbe1c439050a69e5b9a4208e --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/coder/kernels/range_coder.h" + +#include + +#include "tensorflow/core/lib/random/distribution_sampler.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +void RangeEncodeDecodeTest(int precision, random::SimplePhilox* gen) { + constexpr int kAlphabetSize = 256; + + std::vector distribution_weight; + distribution_weight.reserve(kAlphabetSize); + for (int i = 1; i <= kAlphabetSize; ++i) { + distribution_weight.push_back(std::pow(static_cast(i), -2.0f)); + } + + random::DistributionSampler sampler(distribution_weight); + + const int multiplier = (precision > 7) ? 32 : 1; + std::vector histogram(kAlphabetSize, multiplier - 1); + + const int data_size = + (multiplier << precision) - histogram.size() * (multiplier - 1); + CHECK_GE(data_size, 0); + std::vector data(data_size); + for (uint8& x : data) { + x = sampler.Sample(gen); + ++histogram[x]; + } + + std::vector cdf(histogram.size() + 1, 0); + int partial_sum = 0; + for (int i = 0; i < histogram.size(); ++i) { + partial_sum += histogram[i]; + cdf[i + 1] = partial_sum / multiplier; + } + + ASSERT_EQ(cdf.front(), 0); + ASSERT_EQ(cdf.back(), 1 << precision); + + std::vector ideal_code_length(histogram.size()); + const double normalizer = static_cast(1 << precision); + for (int i = 0; i < ideal_code_length.size(); ++i) { + ideal_code_length[i] = -std::log2((cdf[i + 1] - cdf[i]) / normalizer); + } + + RangeEncoder encoder(precision); + string encoded; + double ideal_length = 0.0; + for (uint8 x : data) { + encoder.Encode(cdf[x], cdf[x + 1], &encoded); + ideal_length += ideal_code_length[x]; + } + encoder.Finalize(&encoded); + + LOG(INFO) << "Encoded string length (bits): " << 8 * encoded.size() + << ", whereas ideal " << ideal_length << " (" + << (8 * encoded.size()) / ideal_length << " of ideal) " + << " (ideal compression rate " << ideal_length / (8 * data.size()) + << ")"; + + RangeDecoder decoder(encoded, precision); + for (int i = 0; i < data.size(); ++i) { + const int32 decoded = decoder.Decode(cdf); + ASSERT_EQ(decoded, static_cast(data[i])) << i; + } +} + +TEST(RangeCoderTest, Precision1To11) { + random::PhiloxRandom gen(random::New64(), random::New64()); + random::SimplePhilox rand(&gen); + const int precision = 1 + rand.Uniform(11); + RangeEncodeDecodeTest(precision, &rand); +} + +TEST(RangeCoderTest, Precision12To16) { + random::PhiloxRandom gen(random::New64(), random::New64()); + random::SimplePhilox rand(&gen); + for (int precision = 12; precision < 17; ++precision) { + RangeEncodeDecodeTest(precision, &rand); + } +} + +TEST(RangeCoderTest, FinalizeState0) { + constexpr int kPrecision = 2; + + string output; + RangeEncoder encoder(kPrecision); + encoder.Encode(0, 2, &output); + encoder.Finalize(&output); + + RangeDecoder decoder(output, kPrecision); + EXPECT_EQ(decoder.Decode({0, 2, 4}), 0); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/ops/coder_ops.cc b/tensorflow/contrib/coder/ops/coder_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..9056d1a6963d7be92f499db31385fb6afe2dc515 --- /dev/null +++ b/tensorflow/contrib/coder/ops/coder_ops.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +// clang-format off +REGISTER_OP("RangeEncode") + .Input("data: int16") + .Input("cdf: int32") + .Output("encoded: string") + .Attr("precision: int >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Using the provided cumulative distribution functions (CDF) inside `cdf`, returns +a range-code of `data`. + +The shape of `cdf` should have one more axis than the shape of `data`, and the +prefix `cdf.shape[:-1]` should be broadcastable to `data.shape`. That is, for +every `i = 0,...,rank(data) - 1`, the op requires that either +`cdf.shape[i] == 1` or `cdf.shape[i] == data.shape[i]`. Note that this +broadcasting is limited in the sense that the number of axes must match, and +broadcasts only `cdf` but not `data`. + +`data` should have an upper bound `m > 0` such that each element is an integer +in range `[0, m)`. Then the last dimension size of `cdf` must be `m + 1`. For +each element of `data`, the innermost strip of `cdf` is a vector representing a +CDF. For each k = 0,...,m, `cdf[..., k] / 2^precision` is the probability that +an outcome is less than `k` (not less than or equal to). + +``` + cdf[..., 0] / 2^precision = Pr(data[...] < 0) + cdf[..., 1] / 2^precision = Pr(data[...] < 1) = Pr(data[...] <= 0) + cdf[..., 2] / 2^precision = Pr(data[...] < 2) = Pr(data[...] <= 1) + ... + cdf[..., m] / 2^precision = Pr(data[...] < m) = 1 +``` + +Therefore each element of `cdf` must be in `[0, 2^precision]`. + +Ideally `cdf[..., m]` should equal to `2^precision` but this is not a hard +requirement as long as `cdf[..., m] <= 2^precision`. + +The encoded string neither contains the shape information of the encoded data +nor a termination symbol. Therefore the shape of the encoded data must be +explicitly provided to the decoder. + +Implementation notes: + +- Because of potential performance issues, the op does not check whether +elements of `data` is in the correct range `[0, m)`, or if `cdf` satisfies +monotonic increase property. + +- For the range coder to decode the encoded string correctly, the decoder should +be able to reproduce the internal states of the encoder precisely. Otherwise, +the decoding would fail and once an error occur, all subsequent decoded values +are incorrect. For this reason, the range coder uses integer arithmetics and +avoids using any floating point operations internally, and `cdf` should contain +integers representing quantized probability mass rather than floating points. + +data: An int32 tensor. +cdf: An int32 tensor representing the CDF's of `data`. Each integer is divided + by `2^precision` to represent a fraction. +encoded: A range-coded scalar string. +precision: The number of bits for probability quantization. Must be <= 16. +)doc"); + + +REGISTER_OP("RangeDecode") + .Input("encoded: string") + .Input("shape: int32") + .Input("cdf: int32") + .Output("decoded: int16") + .Attr("precision: int >= 1") + .SetShapeFn([] (InferenceContext* c) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); + c->set_output(0, out); + return Status::OK(); + }) + .Doc(R"doc( +Decodes a range-coded `code` into an int32 tensor of shape `shape`. + +This is the reverse op of RangeEncode. The shape of the tensor that was encoded +should be known by the caller. + +Implementation notes: + +- If wrong input was given (e.g., corrupt `encoded` string, or `cdf` or +`precision` do not match encoder), the decode is unsuccessful. Because of +potential performance issues, the decoder does not return error status. + +encoded: A scalar string tensor from RangeEncode. +shape: An int32 1-D tensor representing the shape of the data encoded by + RangeEncode. +decoded: An int32 tensor with shape equal to `shape`. +precision: The number of bits for probability quantization. Must be <= 16, and + must match the precision used by RangeEncode that produced `encoded`. +)doc"); +// clang-format on +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/python/ops/coder_ops.py b/tensorflow/contrib/coder/python/ops/coder_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..bb262e338baf1d9c3c043f03a02c2d2851e22b49 --- /dev/null +++ b/tensorflow/contrib/coder/python/ops/coder_ops.py @@ -0,0 +1,30 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Range coder operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.contrib.coder.python.ops import gen_coder_ops +from tensorflow.contrib.coder.python.ops.gen_coder_ops import * +# pylint: enable=wildcard-import,unused-import +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + + +_coder_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("_coder_ops.so")) diff --git a/tensorflow/contrib/coder/python/ops/coder_ops_test.py b/tensorflow/contrib/coder/python/ops/coder_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d5e14e7a641b5673e97882daf2b5a1796ee1bbef --- /dev/null +++ b/tensorflow/contrib/coder/python/ops/coder_ops_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Coder operations tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.coder.python.ops import coder_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class CoderOpsTest(test.TestCase): + """Coder ops test. + + Coder ops have C++ tests. Python test just ensures that Python binding is not + broken. + """ + + def testReadmeExample(self): + data = random_ops.random_uniform((128, 128), 0, 10, dtype=dtypes.int32) + histogram = math_ops.bincount(data, minlength=10, maxlength=10) + cdf = math_ops.cumsum(histogram, exclusive=False) + cdf = array_ops.pad(cdf, [[1, 0]]) + cdf = array_ops.reshape(cdf, [1, 1, -1]) + + data = math_ops.cast(data, dtypes.int16) + encoded = coder_ops.range_encode(data, cdf, precision=14) + decoded = coder_ops.range_decode( + encoded, array_ops.shape(data), cdf, precision=14) + + with self.test_session() as sess: + self.assertAllEqual(*sess.run((data, decoded))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/copy_graph/__init__.py b/tensorflow/contrib/copy_graph/__init__.py index 30a0aac140b576c501595fd6c8767b7dddde8e58..61ee39e4be1f0471309bb2672476dd9100cbfd49 100644 --- a/tensorflow/contrib/copy_graph/__init__.py +++ b/tensorflow/contrib/copy_graph/__init__.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Functions to copy elements between graphs. - -See the @{$python/contrib.copy_graph} guide. """ from __future__ import absolute_import diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index d060eda0a74010db10d9506b2a1c2345b2731709..bae66ffd4289308f2cbfc730ec50d057b13923fb 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -225,6 +225,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, new_original_op, op_def) #Use Graph's hidden methods to add the op + to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index bc749339bd4d49c8372bc731da98732f8c19cbe1..046c509626bc2eb20a65c0b38495ff37c294e0e1 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -16,15 +16,15 @@ See the @{$python/contrib.crf} guide. -@@crf_sequence_score -@@crf_log_norm -@@crf_log_likelihood -@@crf_unary_score @@crf_binary_score @@crf_decode -@@CrfForwardRnnCell -@@CrfDecodeForwardRnnCell +@@crf_log_likelihood +@@crf_log_norm +@@crf_sequence_score +@@crf_unary_score @@CrfDecodeBackwardRnnCell +@@CrfDecodeForwardRnnCell +@@CrfForwardRnnCell @@viterbi_decode """ @@ -32,16 +32,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks from tensorflow.contrib.crf.python.ops.crf import crf_binary_score from tensorflow.contrib.crf.python.ops.crf import crf_decode from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood from tensorflow.contrib.crf.python.ops.crf import crf_log_norm from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score from tensorflow.contrib.crf.python.ops.crf import crf_unary_score -from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell -from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell +from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell +from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell from tensorflow.contrib.crf.python.ops.crf import viterbi_decode from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index b47fb426a193e0fcc075deafae3eaab698f18ec9..721dc4d0801d1f0e116921888e3851a95e0b72b0 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -179,17 +179,6 @@ class CrfTest(test.TestCase): tf_total_log_likelihood = sess.run(total_log_likelihood) self.assertAllClose(tf_total_log_likelihood, 0.0) - def testLengthsToMasks(self): - with self.test_session() as sess: - sequence_lengths = [4, 1, 8, 2] - max_sequence_length = max(sequence_lengths) - mask = crf._lengths_to_masks(sequence_lengths, max_sequence_length) - tf_mask = sess.run(mask) - self.assertEqual(len(tf_mask), len(sequence_lengths)) - for m, l in zip(tf_mask, sequence_lengths): - self.assertAllEqual(m[:l], [1] * l) - self.assertAllEqual(m[l:], [0] * (len(m) - l)) - def testViterbiDecode(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 7f5ae937b26f465076c6976429697c35924432e5..62708636c6181ca63cddf2b2e7c84d3da740282a 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -70,25 +70,6 @@ __all__ = [ ] -def _lengths_to_masks(lengths, max_length): - """Creates a binary matrix that can be used to mask away padding. - - Args: - lengths: A vector of integers representing lengths. - max_length: An integer indicating the maximum length. All values in - lengths should be less than max_length. - Returns: - masks: Masks that can be used to get rid of padding. - """ - tiled_ranges = array_ops.tile( - array_ops.expand_dims(math_ops.range(max_length), 0), - [array_ops.shape(lengths)[0], 1]) - lengths = array_ops.expand_dims(lengths, 1) - masks = math_ops.to_float( - math_ops.to_int64(tiled_ranges) < math_ops.to_int64(lengths)) - return masks - - def crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params): """Computes the unnormalized score for a tag sequence. @@ -234,7 +215,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): array_ops.gather(flattened_inputs, flattened_tag_indices), [batch_size, max_seq_len]) - masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1]) + masks = array_ops.sequence_mask(sequence_lengths, + maxlen=array_ops.shape(tag_indices)[1], + dtype=dtypes.float32) unary_scores = math_ops.reduce_sum(unary_scores * masks, 1) return unary_scores @@ -268,7 +251,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): binary_scores = array_ops.gather(flattened_transition_params, flattened_transition_indices) - masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1]) + masks = array_ops.sequence_mask(sequence_lengths, + maxlen=array_ops.shape(tag_indices)[1], + dtype=dtypes.float32) truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1]) binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1) return binary_scores diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index fce2c03e69bc4b8b0ac46b8e081a33c43c9d41ab..fec358c4e1067dc8dc8173d1b9d05dc90b90ca05 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -25,6 +25,7 @@ tf_custom_op_library( ], deps = [ "//tensorflow/core/kernels:bounds_check_lib", + "@farmhash_archive//:farmhash", ], ) @@ -39,6 +40,7 @@ tf_kernel_library( "//tensorflow/core:stream_executor", "//tensorflow/core/kernels:bounds_check_lib", "//third_party/eigen3", + "@farmhash_archive//:farmhash", ], ) @@ -146,10 +148,10 @@ cuda_py_test( cuda_py_test( name = "cudnn_rnn_ops_benchmark", - size = "large", + size = "small", srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"], additional_deps = [ - ":cudnn_rnn_ops_py", + ":cudnn_rnn_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", @@ -164,7 +166,6 @@ cuda_py_test( "//tensorflow/python:variables", ], tags = [ - "manual", "noasan", # http://b/62067814 "nomsan", "notsan", diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc index 5d5f593d016a3bb9f7b5ea8f5cd40c29268dc4f5..ba9686e94ee7072cc485c955decb2287bd4a56f3 100644 --- a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/env_var.h" @@ -369,6 +370,27 @@ struct CudnnModelShapes { } }; +// Utility class for using CudnnModelShapes as a hash table key. +struct CudnnModelShapesHasher { + uint64 operator()(const CudnnModelShapes& to_hash) const { + uint64 hash = static_cast(to_hash.num_layers); + hash = tensorflow::FingerprintCat64( + hash, static_cast(to_hash.input_size)); + hash = tensorflow::FingerprintCat64(hash, + static_cast(to_hash.num_units)); + return tensorflow::FingerprintCat64(hash, + static_cast(to_hash.dir_count)); + } +}; + +// Utility class for using CudnnModelShapes as a hash table key. +struct CudnnModelShapesComparator { + bool operator()(const CudnnModelShapes& first, + const CudnnModelShapes& second) const { + return first.IsCompatibleWith(second); + } +}; + // Extract and checks the forward input tensors, parameters, and shapes from the // OpKernelContext. Status ExtractForwardInput(OpKernelContext* context, @@ -627,7 +649,7 @@ class CudnnRNNParamsToCanonical : public CudnnRNNKernelCommon { } const int num_params_per_layer = num_params_ / num_layers / num_dirs; // Number of params applied on inputs. The rest are applied on recurrent - // hiddden states. + // hidden states. const int num_params_input_state = num_params_per_layer / 2; CHECK(num_params_ % (num_layers * num_dirs) == 0) << "Number of params is not a multiple of num_layers * num_dirs."; @@ -764,6 +786,13 @@ TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU +// Pointers to RNN scratch space for a specific set of shape parameters (used as +// a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp). +struct RnnScratchSpace { + std::unique_ptr rnn_desc; + std::unique_ptr dropout_state_allocator; +}; + // Run the forward operation of the RNN model. template class CudnnRNNForwardOp : public CudnnRNNKernelCommon { @@ -808,32 +837,7 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, model_shapes.input_size, &input_mode)); - // TODO(zhengxq): cache the descriptor so we don't have to create them all - // the time. auto data_type = ToDataType::value; - { - mutex_lock l(mu_); - if (model_shapes_ == nullptr) { - model_shapes_.reset(new CudnnModelShapes(model_shapes)); - } else { - OP_REQUIRES(context, model_shapes_->IsCompatibleWith(model_shapes), - errors::InvalidArgument( - "Incompatible rnn model shapes inferred: expecting ", - model_shapes_->RnnDescDebugString(), ", getting ", - model_shapes.RnnDescDebugString(), ".")); - } - if (rnn_desc_ == nullptr || ResetRndGenState()) { - dropout_state_allocator_.reset( - new CudnnRNNPersistentSpaceAllocator(context)); - auto rnn_desc_s = executor->createRnnDescriptor( - model_shapes_->num_layers, model_shapes_->num_units, - model_shapes_->input_size, input_mode, rnn_direction_mode(), - rnn_mode(), data_type, dropout(), seed(), - dropout_state_allocator_.get()); - OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); - rnn_desc_ = std::move(rnn_desc_s.ConsumeValueOrDie()); - } - } auto input_desc_s = executor->createRnnSequenceTensorDescriptor( input_shape.dim_size(0), input_shape.dim_size(1), @@ -882,14 +886,27 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { bool launch_status = false; { mutex_lock l(mu_); + RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes]; + if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) { + CudnnRNNPersistentSpaceAllocator* dropout_state_allocator = + new CudnnRNNPersistentSpaceAllocator(context); + rnn_state.dropout_state_allocator.reset(dropout_state_allocator); + auto rnn_desc_s = executor->createRnnDescriptor( + model_shapes.num_layers, model_shapes.num_units, + model_shapes.input_size, input_mode, rnn_direction_mode(), + rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator); + OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); + rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie()); + } launch_status = stream - ->ThenRnnForward( - *rnn_desc_, *input_desc, input_data, *hidden_state_desc, - input_h_data, *hidden_state_desc, input_c_data, params_data, - *output_desc, &output_data, *hidden_state_desc, - &output_h_data, *hidden_state_desc, &output_c_data, - is_training_, &reserve_space_allocator, &workspace_allocator) + ->ThenRnnForward(*rnn_state.rnn_desc, *input_desc, input_data, + *hidden_state_desc, input_h_data, + *hidden_state_desc, input_c_data, params_data, + *output_desc, &output_data, *hidden_state_desc, + &output_h_data, *hidden_state_desc, + &output_c_data, is_training_, + &reserve_space_allocator, &workspace_allocator) .ok(); } OP_REQUIRES(context, launch_status, @@ -899,10 +916,9 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { private: mutex mu_; bool is_training_; - std::unique_ptr model_shapes_ GUARDED_BY(mu_); - std::unique_ptr rnn_desc_ GUARDED_BY(mu_); - std::unique_ptr dropout_state_allocator_ - GUARDED_BY(mu_); + std::unordered_map + rnn_state_cache_ GUARDED_BY(mu_); }; #define REGISTER_GPU(T) \ @@ -1022,32 +1038,6 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, model_shapes.input_size, &input_mode)); - // TODO(zhengxq): cache the descriptor so we don't have to create them all - // the time. - { - mutex_lock l(mu_); - if (model_shapes_ == nullptr) { - model_shapes_.reset(new CudnnModelShapes(model_shapes)); - } else { - OP_REQUIRES(context, model_shapes_->IsCompatibleWith(model_shapes), - errors::InvalidArgument( - "Incompatible rnn model shapes inferred: expecting ", - model_shapes_->RnnDescDebugString(), ", getting ", - model_shapes.RnnDescDebugString(), ".")); - } - - if (rnn_desc_ == nullptr || ResetRndGenState()) { - dropout_state_allocator_.reset( - new CudnnRNNPersistentSpaceAllocator(context)); - auto rnn_desc_s = executor->createRnnDescriptor( - model_shapes.num_layers, model_shapes.num_units, - model_shapes.input_size, input_mode, rnn_direction_mode(), - rnn_mode(), data_type, dropout(), seed(), - dropout_state_allocator_.get()); - OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); - rnn_desc_ = std::move(rnn_desc_s.ConsumeValueOrDie()); - } - } auto input_desc_s = executor->createRnnSequenceTensorDescriptor( input_shape.dim_size(0), input_shape.dim_size(1), @@ -1100,17 +1090,30 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { bool launch_status = false; { mutex_lock l(mu_); + RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes]; + if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) { + CudnnRNNPersistentSpaceAllocator* dropout_state_allocator = + new CudnnRNNPersistentSpaceAllocator(context); + rnn_state.dropout_state_allocator.reset(dropout_state_allocator); + auto rnn_desc_s = executor->createRnnDescriptor( + model_shapes.num_layers, model_shapes.num_units, + model_shapes.input_size, input_mode, rnn_direction_mode(), + rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator); + OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); + rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie()); + } launch_status = stream - ->ThenRnnBackward( - *rnn_desc_, *input_desc, input_data, *hidden_state_desc, - input_h_data, *hidden_state_desc, input_c_data, params_data, - *output_desc, output_data, *hidden_state_desc, output_h_data, - *hidden_state_desc, output_c_data, output_backprop_data, - output_h_backprop_data, output_c_backprop_data, - &input_backprop_data, &input_h_backprop_data, - &input_c_backprop_data, ¶ms_backprop_data, - &reserve_space_uint8, &workspace_allocator) + ->ThenRnnBackward(*rnn_state.rnn_desc, *input_desc, input_data, + *hidden_state_desc, input_h_data, + *hidden_state_desc, input_c_data, params_data, + *output_desc, output_data, *hidden_state_desc, + output_h_data, *hidden_state_desc, + output_c_data, output_backprop_data, + output_h_backprop_data, output_c_backprop_data, + &input_backprop_data, &input_h_backprop_data, + &input_c_backprop_data, ¶ms_backprop_data, + &reserve_space_uint8, &workspace_allocator) .ok(); } OP_REQUIRES(context, launch_status, @@ -1119,10 +1122,9 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { private: mutex mu_; - std::unique_ptr model_shapes_ GUARDED_BY(mu_); - std::unique_ptr rnn_desc_ GUARDED_BY(mu_); - std::unique_ptr dropout_state_allocator_ - GUARDED_BY(mu_); + std::unordered_map + rnn_state_cache_ GUARDED_BY(mu_); }; #define REGISTER_GPU(T) \ diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py index ff409ac71826f1f0f57e9133d768003f849abc09..4fc5ff1bd1887c4532e95fcf0e791d72b20471b0 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py @@ -20,8 +20,8 @@ from __future__ import print_function import time +from tensorflow.contrib import rnn as contrib_rnn from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops -from tensorflow.contrib.rnn.python.ops import core_rnn from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -29,8 +29,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import rnn from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -44,19 +43,19 @@ class CudnnRNNBenchmark(test.Benchmark): "large": { "num_layers": 4, "num_units": 1024, - "seq_length": 40, + "seq_length": 50, "batch_size": 64, }, "medium": { "num_layers": 4, "num_units": 512, - "seq_length": 30, + "seq_length": 50, "batch_size": 64, }, "small": { "num_layers": 4, "num_units": 128, - "seq_length": 20, + "seq_length": 50, "batch_size": 64, }, } @@ -71,7 +70,7 @@ class CudnnRNNBenchmark(test.Benchmark): def _BenchmarkOp(self, op, desc): burn_in_steps = 10 - benchmark_steps = 40 + benchmark_steps = 20 with session.Session() as sess: sess.run(variables.global_variables_initializer()) for i in xrange(burn_in_steps + benchmark_steps): @@ -126,16 +125,12 @@ class CudnnRNNBenchmark(test.Benchmark): seq_length = config["seq_length"] with ops.Graph().as_default(), ops.device("/device:GPU:0"): - inputs = seq_length * [ - array_ops.zeros([batch_size, num_units], dtypes.float32) - ] - initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127) - - cell = rnn_cell.LSTMCell( - num_units=num_units, initializer=initializer, state_is_tuple=True) - multi_cell = rnn_cell.MultiRNNCell( - [cell() for _ in range(num_layers)]) - outputs, final_state = core_rnn.static_rnn( + inputs = array_ops.zeros([batch_size, seq_length, num_units], + dtypes.float32) + + multi_cell = contrib_rnn.MultiRNNCell( + [contrib_rnn.BasicLSTMCell(num_units) for _ in range(num_layers)]) + outputs, final_state = rnn.dynamic_rnn( multi_cell, inputs, dtype=dtypes.float32) trainable_variables = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) @@ -154,14 +149,12 @@ class CudnnRNNBenchmark(test.Benchmark): seq_length = config["seq_length"] with ops.Graph().as_default(), ops.device("/device:GPU:0"): - inputs = seq_length * [ - array_ops.zeros([batch_size, num_units], dtypes.float32) - ] - cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units) # pylint: disable=cell-var-from-loop - - multi_cell = rnn_cell.MultiRNNCell( - [cell() for _ in range(num_layers)]) - outputs, final_state = core_rnn.static_rnn( + inputs = array_ops.zeros([batch_size, seq_length, num_units], + dtypes.float32) + + multi_cell = contrib_rnn.MultiRNNCell( + [lstm_ops.LSTMBlockCell(num_units) for _ in range(num_layers)]) + outputs, final_state = rnn.dynamic_rnn( multi_cell, inputs, dtype=dtypes.float32) trainable_variables = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index e65394cba07574ed49398981f1cbd8bcb402e24f..49d305cb0dd0387c34b7feb79ef631eac9e935cd 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -29,6 +29,8 @@ import numpy as np from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -314,6 +316,101 @@ class CudnnRNNTestBasic(TensorFlowTestCase): self.assertEqual(0, total_sum2_v) self.assertEqual(0, total_sum3_v) + def testSaveableGraphDeviceAssignment(self): + num_layers = 4 + num_units = 2 + batch_size = 8 + direction = CUDNN_RNN_UNIDIRECTION + dir_count = 1 + + def DeviceFn(op): + if op.type in ("Variable", "VariableV2"): + return "/cpu:0" + else: + return "/gpu:0" + + with ops.Graph().as_default() as g: + with ops.device(DeviceFn): + with vs.variable_scope("main"): + kernel_initializer = init_ops.constant_initializer(3.14) + bias_initializer = init_ops.constant_initializer(1.59) + inputs = random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], + dtype=dtypes.float32) + + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + outputs = lstm(inputs) + + # saver is created in the scope of DeviceFn. + saver = saver_lib.Saver() + + with self.test_session(use_gpu=True, graph=g) as sess: + save_path = os.path.join(self.get_temp_dir(), + "test-saveable-device-assignment") + sess.run(variables.global_variables_initializer()) + + saver.save(sess, save_path) + saver.restore(sess, save_path) + sess.run(outputs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testDifferentShapesEager(self): + # Checks that kernel caching does not cause sharing of temporary storage + # across different input shapes when executing eagerly. + with context.eager_mode(): + with ops.device("gpu:0"): + first_output, _ = cudnn_rnn.CudnnGRU(1, 100)( + array_ops.zeros([28, 100, 28])) + second_output, _ = cudnn_rnn.CudnnGRU(1, 100)( + array_ops.zeros([28, 100, 100])) + self.assertAllEqual([28, 100, 100], first_output.shape) + self.assertAllEqual([28, 100, 100], second_output.shape) + + def _LossFunc(): + first_output, _ = cudnn_rnn.CudnnGRU(1, 100)( + array_ops.zeros([28, 100, 28])) + second_output, _ = cudnn_rnn.CudnnGRU(1, 100)( + array_ops.zeros([28, 100, 100])) + return (math_ops.reduce_sum(first_output) + + math_ops.reduce_sum(second_output)) + + backprop.implicit_grad(_LossFunc)() + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testDifferentShapesGraph(self): + # Tests that a single kernel instance presented with multiple input shapes + # does not crash with graph execution. + with ops.device("gpu:0"): + layer = cudnn_rnn.CudnnGRU(1, 100) + layer(array_ops.zeros([28, 100, 100])) + + def _Cond(index, accumulation): + del accumulation # unused + return math_ops.less(index, 4) + + def _Body(index, accumulation): + layer_input = accumulation[:, :, 10 * (1 + index % 2):] + output, _ = layer(layer_input) + return index + 1, accumulation + output + + original_input = array_ops.zeros([28, 100, 100]) + _, accumulation = control_flow_ops.while_loop(_Cond, _Body, + [0, original_input]) + grad, = gradients.gradients( + math_ops.reduce_sum(accumulation), (original_input,)) + init_op = variables.global_variables_initializer() + with self.test_session() as sess: + sess.run(init_op) + accumulation_eval, grad_eval = sess.run((accumulation, grad)) + self.assertAllEqual([28, 100, 100], accumulation_eval.shape) + self.assertAllEqual([28, 100, 100], grad_eval.shape) + # TODO(jamesqin): Transform to parameterized test after it is included in the # TF open source codebase. diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 37c61a71a3bdac4fadef58ba8c24b853fb3638ef..36fba917a8f56c26fd5b4c3468d1d980a8ba2ba5 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -176,8 +176,9 @@ class _CudnnRNN(base_layer.Layer): otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Can be either 'unidirectional' or 'bidirectional' - dropout: dropout rate, a number between [0, 1]. Dropout is applied on - inputs of each layer. When set to 0, dropout is disabled. + dropout: dropout rate, a number between [0, 1]. Dropout is applied between + each layer (no dropout is applied for a model with a single layer). + When set to 0, dropout is disabled. seed: the op seed used for initializing dropout. See @{tf.set_random_seed} for behavior. dtype: tf.float16, tf.float32 or tf.float64 @@ -358,7 +359,7 @@ class _CudnnRNN(base_layer.Layer): # Create saveable in the outer scope of the cudnn subgraph, such that # alternative subgraph with platform-independent rnn cells can load the # checkpoints directly. - if not (self.built or vs.get_variable_scope().reuse): + if not (self.built or vs.get_variable_scope().reuse is True): self._create_saveable() self.built = True @@ -450,17 +451,18 @@ class _CudnnRNN(base_layer.Layer): raise RuntimeError( "%s._canonical_to_opaque invoked before input shape is known" % type(self).__name__) - return cudnn_rnn_ops.cudnn_rnn_canonical_to_opaque_params( - rnn_mode=self._rnn_mode, - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - input_mode=self._input_mode, - seed=self._seed, - dropout=self._dropout, - direction=self._direction) + with ops.device("/gpu:0"): + return cudnn_rnn_ops.cudnn_rnn_canonical_to_opaque_params( + rnn_mode=self._rnn_mode, + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + input_mode=self._input_mode, + seed=self._seed, + dropout=self._dropout, + direction=self._direction) def _forward(self, inputs, h, c, opaque_params, training): output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access @@ -489,14 +491,14 @@ class _CudnnRNN(base_layer.Layer): if self._saveable is not None: raise RuntimeError("Cudnn saveable already created.") self._saveable = self._saveable_cls( # pylint:disable=not-callable - self.trainable_variables[0], - self.num_layers, - self.num_units, - self.input_size, - self.input_mode, - self.direction, + opaque_params=self.trainable_variables[0], + num_layers=self.num_layers, + num_units=self.num_units, + input_size=self.input_size, + input_mode=self.input_mode, + direction=self.direction, scope=vs.get_variable_scope(), - name="%s_saveable" % self.trainable_variables[0].op.name) + name="%s_saveable" % self.trainable_variables[0].name.split(":")[0]) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index dcd3d4732a27ae4bec579ac12ac568dc4a53baaa..e87162f0ee9cc4eed795555171f55a93639e83cf 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -72,7 +72,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): def __init__(self, num_units, reuse=None): super(CudnnCompatibleLSTMCell, self).__init__( num_units, forget_bias=0, cell_clip=None, use_peephole=False, - reuse=reuse) + reuse=reuse, name="cudnn_compatible_lstm_cell") self._names.update({"scope": "cudnn_compatible_lstm_cell"}) @@ -303,16 +303,17 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): Returns: 2 list for weights and biases respectively. """ - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - params=self._variables, - num_params=self._num_params, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) - return (weights, biases) + with ops.device("/gpu:0"): + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=self._variables, + num_params=self._num_params, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) + return (weights, biases) def _CanonicalToOpaqueParams(self, cu_weights, cu_biases): """Converts from Cudnn canonical format to opaque params. @@ -323,15 +324,16 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): Returns: a single opaque tensor. """ - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) + with ops.device("/gpu:0"): + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) def _TransformCanonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. @@ -1352,7 +1354,7 @@ class _CudnnRNN(object): params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. Returns: - output: the output sequuence. + output: the output sequence. output_h: the final state for h. output_c: the final state for c. This is only relevant for LSTM. """ @@ -1470,7 +1472,7 @@ class CudnnLSTM(_CudnnRNN): params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. Returns: - output: the output sequuence. + output: the output sequence. output_h: the final state for h. output_c: the final state for c. """ @@ -1540,7 +1542,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. Returns: - output: the output sequuence. + output: the output sequence. output_h: the final state for h. """ return _cudnn_rnn_no_input_c( diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index f7d8a084d9c12c05c411ae0751854d1823a818ec..8ecc003348d70379ee48d050e63e93d0dd38efaa 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -18,7 +18,9 @@ py_library( "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", ], diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 7c6244f22b0f41656369595d3e3e6c23b7088bcb..daeb6a610533404044d42033709d644deb481024 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`tf.contrib.data.Dataset` API for input pipelines. +"""`tf.contrib.data` API for input pipelines. + +This module contains the experimental (less stable) counterpart to the +`tf.data` API. See @{tf.data.Dataset} and @{tf.data.Iterator} for the +stable classes. See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@ -24,18 +28,20 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@TextLineDataset @@batch_and_drop_remainder -@@padded_batch_and_drop_remainder @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ignore_errors @@make_saveable_from_iterator -@@read_batch_features -@@unbatch +@@map_and_batch +@@padded_batch_and_drop_remainder @@parallel_interleave +@@read_batch_features @@rejection_resample @@scan +@@shuffle_and_repeat @@sloppy_interleave +@@unbatch @@get_single_element """ @@ -48,6 +54,7 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch +from tensorflow.contrib.data.python.ops.batching import map_and_batch from tensorflow.contrib.data.python.ops.batching import padded_batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import unbatch from tensorflow.contrib.data.python.ops.counter import Counter @@ -66,6 +73,7 @@ from tensorflow.contrib.data.python.ops.readers import TextLineDataset from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan +from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index c9a3537c70c711290fb1111a1594e6dea3bc07a9..d3df14bdd03476e9ee4015b374512e5bb9893a63 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -83,11 +83,10 @@ class FunctionBufferingResource : public ResourceBase { return Status::OK(); } AttrValueMap attr_values = func_.attr(); - AttrValue v; - v.set_s(target_device_); - AddAttr("_target", v, &attr_values); - - return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), &handle_); + FunctionLibraryRuntime::InstantiateOptions opts; + opts.target = target_device_; + return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts, + &handle_); } // Returns true if we've got to the end of the sequence and exhausted the diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 1d4817fa2670317f4f4e9e63c724a79e18aa35bc..1fbf18f30a293de697826885d15bb95b40568daa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", @@ -36,6 +36,7 @@ py_test( srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", @@ -75,13 +76,11 @@ py_test( srcs = ["concatenate_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -89,7 +88,7 @@ py_test( py_test( name = "dataset_constructor_op_test", - size = "small", + size = "medium", srcs = ["dataset_constructor_op_test.py"], srcs_version = "PY2AND3", tags = [ @@ -118,7 +117,6 @@ py_test( py_library( name = "dataset_serialization_test", - testonly = 1, srcs = [ "dataset_serialization_test_base.py", ], @@ -157,14 +155,13 @@ py_test( ], ) -py_test( +tf_py_test( name = "flat_map_dataset_op_test", - size = "small", + size = "medium", srcs = ["flat_map_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":dataset_serialization_test", + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -177,17 +174,19 @@ py_test( "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//third_party/py/numpy", ], + grpc_enabled = True, + tags = ["no_pip"], ) py_test( name = "interleave_dataset_op_test", - size = "small", + size = "medium", srcs = ["interleave_dataset_op_test.py"], srcs_version = "PY2AND3", tags = [ - "manual", # b/67958761 + "no_oss", + "no_pip", ], deps = [ ":dataset_serialization_test", @@ -207,13 +206,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "iterator_ops_cluster_test", size = "small", srcs = ["iterator_ops_cluster_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], - deps = [ + additional_deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -227,14 +224,19 @@ py_test( "//tensorflow/python:session", "//tensorflow/python/data/ops:iterator_ops", ], + grpc_enabled = True, + tags = [ + "no_windows", + "oss_serial", + ], ) -py_test( +tf_py_test( name = "iterator_ops_test", size = "small", srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", @@ -256,8 +258,8 @@ py_test( "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", ], + grpc_enabled = True, ) py_test( @@ -277,7 +279,7 @@ py_test( py_test( name = "map_dataset_op_test", - size = "small", + size = "medium", srcs = ["map_dataset_op_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -304,6 +306,7 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -327,8 +330,8 @@ py_test( srcs = ["range_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -339,11 +342,8 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:io_ops", "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", "//tensorflow/python:variables", - "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -389,8 +389,27 @@ py_test( ) py_test( - name = "sequence_dataset_op_test", + name = "scan_dataset_op_test", size = "small", + srcs = ["scan_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sequence_dataset_op_test", + size = "medium", srcs = ["sequence_dataset_op_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -419,20 +438,20 @@ py_test( py_test( name = "shuffle_dataset_op_test", - size = "small", + size = "medium", srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -450,6 +469,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "@org_sqlite//:python", ], ) @@ -458,11 +478,32 @@ py_test( size = "small", srcs = ["stats_dataset_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + ], +) + +py_test( + name = "unique_dataset_op_test", + size = "small", + srcs = ["unique_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/stateless", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index a939b3c841286a3b5786268dc3a9c82fd7359bfb..015f69c5673f185c53e61a5df2636333699ae203 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -305,10 +305,10 @@ class BatchDatasetTest(test.TestCase): iterator = ( dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, - [12])).make_initializable_iterator()) + batching.dense_to_sparse_batch(4, [12])) + .make_initializable_iterator()) init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + get_next = iterator.get_next() with self.test_session() as sess: sess.run(init_op) @@ -334,9 +334,9 @@ class BatchDatasetTest(test.TestCase): dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.fill([x, x], x)).apply( batching.dense_to_sparse_batch( - 4, [5, -1])).make_initializable_iterator()) + 4, [5, None])).make_initializable_iterator()) init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + get_next = iterator.get_next() with self.test_session() as sess: sess.run(init_op) @@ -363,25 +363,18 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) - iterator = ( - dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, [-2])) - .make_initializable_iterator()) - init_op = iterator.initializer - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dimension -2 must be >= -1"): - sess.run(init_op) + with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"): + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator() def testDenseToSparseBatchDatasetShapeErrors(self): input_tensor = array_ops.placeholder(dtypes.int32) iterator = ( dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, - [12])).make_initializable_iterator()) + batching.dense_to_sparse_batch(4, [12])) + .make_initializable_iterator()) init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + get_next = iterator.get_next() with self.test_session() as sess: # Initialize with an input tensor of incompatible rank. @@ -577,7 +570,7 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def testBatchAndMapDataset(self): + def _testBatchAndMapDatasetHelper(self, num_parallel_batches=1): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). @@ -593,7 +586,10 @@ class BatchDatasetTest(test.TestCase): iterator = ( dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply( - batching.map_and_batch(_map_fn, batch_size)) + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -627,7 +623,11 @@ class BatchDatasetTest(test.TestCase): for j in range(8): self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) - # The last batch should fail with `OutOfRange`. + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range((14 * 7) % 8): + self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, + result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -640,6 +640,12 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + def testBatchAndMapDataset(self): + return self._testBatchAndMapDatasetHelper() + + def testBatchAndMapDatasetWithParallelBatching(self): + return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + def testMapAndBatchSparse(self): def _sparse(i): @@ -722,6 +728,22 @@ class BatchDatasetSerializationTest( lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), num_outputs) + def _build_dataset_dense_to_sparse(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, [12])) + + # TODO(b/70988345): Re-enable when sparse tensors are properly supported by + # the DatasetSerializationTestBase. + def _testDenseToSparseBatchDatasetCore(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) + + num_outputs = len(components) // 4 + self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), + lambda: self._build_dataset_dense_to_sparse(diff_comp), + num_outputs) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 765ed53618958a8c49b26e416c57be28ea3bba73..4d984bb4d76e52c4200ae471550dcf48668c5f89 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.framework import constant_op @@ -160,6 +161,34 @@ class GroupByWindowTest(test.TestCase): self.assertEqual(len(components), sum(counts)) +class GroupByWindowSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( + grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) + + def testCoreGroupByWindow(self): + components = np.array( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 12, + verify_exhausted=False) + + # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # Currently, they use a constant batch size, though should be made to use a # different batch size per key. diff --git a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py index 870352209a08e6bc08bcca227ba455ad1851e8bf..063c71063601002af8168c4facf4057433061ab7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py @@ -17,17 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib class ConcatenateDatasetTest(test.TestCase): @@ -133,139 +130,26 @@ class ConcatenateDatasetTest(test.TestCase): with self.assertRaisesRegexp(TypeError, "have different types"): input_dataset.concatenate(dataset_to_concatenate) - def _iterator_checkpoint_prefix(self): - return os.path.join(self.get_temp_dir(), "iterator") - def _build_graph(self, input_components, to_concatenate_components): - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - iterator = input_dataset.concatenate( - dataset_to_concatenate).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - saveable = iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - # TODO(shivaniagrawal) : non-intuitive way, add support in mata_graph - for t in nest.flatten(get_next): - ops.add_to_collection("get_next", t) - return init_op, get_next - - def _testSaveRestoreUtility(self, start, break_range, stop): - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step - - input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( - np.array([[12], [13], [14], [15]]), 4)) - to_concatenate_components = (np.tile( - np.array([[5], [6], [7], [8], [9]]), 20), np.tile( - np.array([[16], [17], [18], [19], [20]]), 15)) - - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph(input_components, - to_concatenate_components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(start, break_range): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_range, stop): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) +class ConcatenateDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - def testRestoreAtFirstDataset(self): - start = 0 - stop = 9 - break_range = 3 - self._testSaveRestoreUtility(start, break_range, stop) - - def testRestoreAtSecondDataset(self): - start = 0 - stop = 9 - break_range = 6 - self._testSaveRestoreUtility(start, break_range, stop) - - def testRestoreAtBetweenDatasets(self): - start = 0 - stop = 9 - break_range = 4 - self._testSaveRestoreUtility(start, break_range, stop) - - def testRestoreExhaustedIterator(self): - start = 0 - stop = 9 - break_range = 9 - self._testSaveRestoreUtility(start, break_range, stop) - - def testRestoreInModifiedGraph(self): - start = 0 - stop = 9 - break_range = 6 - path = self._iterator_checkpoint_prefix() - step = 0 - - input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( - np.array([[12], [13], [14], [15]]), 4)) + def _build_concatenate_dataset(self, var_array): + input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 4)) to_concatenate_components = (np.tile( - np.array([[5], [6], [7], [8], [9]]), 20), np.tile( - np.array([[16], [17], [18], [19], [20]]), 15)) - - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph(input_components, - to_concatenate_components) - saver = saver_lib.Saver(allow_empty=True) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(start, break_range): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - saver.save(sess, path, step) - - new_to_concatenate_components = (np.array([[5], [6], [7], [8], [9]]), - np.array([[16], [17], [18], [19], [20]])) - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph(input_components, - new_to_concatenate_components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_range, stop): - result = sess.run(get_next) - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + np.array([[5], [6], [7], [8], [9]]), 20), var_array) + + return dataset_ops.Dataset.from_tensor_slices(input_components).concatenate( + dataset_ops.Dataset.from_tensor_slices(to_concatenate_components)) + + def testConcatenateCore(self): + num_outputs = 9 + array = np.tile(np.array([[16], [17], [18], [19], [20]]), 15) + diff_array = np.array([[1], [2], [3], [4], [5]]) + self.run_core_tests(lambda: self._build_concatenate_dataset(array), + lambda: self._build_concatenate_dataset(diff_array), + num_outputs) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index 55a1d3b95b212466b262ad3c26f1efd7ed0e067e..a90ba30e60cef13156719bba24fb553c0acec391 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -717,11 +717,12 @@ class DatasetConstructorTest(test.TestCase): sess.run(var_1.initializer) iterator = dataset.make_initializable_iterator() + sess.run(iterator.initializer) with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Trying to access resource located in device"): - sess.run(iterator.initializer) + errors.FailedPreconditionError, + "Error while reading resource variable Variable"): + sess.run(iterator.get_next()) def testRestructureDataset(self): components = (array_ops.placeholder(dtypes.int32), diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index bf25cc60a1c0efc09bed6501fd2d6f4ccb07764b..7cde6e05b244773966fd7c1bd4ca1e95abf7fd5e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -40,6 +40,8 @@ class DatasetSerializationTestBase(test.TestCase): def tearDown(self): self._delete_ckpt() + # TODO(b/70988345): Support native `tf.SparseTensor` objects and get rid of + # `sparse_tensors` argument. def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): """Runs the core tests. diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index e66ed3f7aa2a512813ef353d2d0744ae67005884..b1937c08f347734d0d6871bd30ed209ff520623a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -41,6 +41,7 @@ from tensorflow.python.platform import test class InterleaveDatasetTest(test.TestCase): def _interleave(self, lists, cycle_length, block_length): + # TODO(b/69678297): Consolidate python interleave implementations. num_open = 0 # `all_iterators` acts as a queue of iterators over each element of `lists`. @@ -255,11 +256,15 @@ class InterleaveDatasetSeriazationTest( class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): + self.input_values = array_ops.placeholder(dtypes.int64, shape=[None]) self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) self.block_length = array_ops.placeholder(dtypes.int64, shape=[]) self.sloppy = array_ops.placeholder(dtypes.bool, shape=[]) + self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[]) + self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[]) + self.error = None self.repeat_count = 2 # Set up threading events used to sequence when items are produced that @@ -276,6 +281,10 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.write_coordination_events[x].wait() self.write_coordination_events[x].clear() self.read_coordination_events[x].release() + if self.error: + err = self.error + self.error = None + raise err # pylint: disable=raising-bad-type return x * x def map_fn(x): @@ -286,11 +295,13 @@ class ParallelInterleaveDatasetTest(test.TestCase): dataset = dataset.repeat(x) return dataset.map(map_fn) - self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values) - .repeat(self.repeat_count).apply( - interleave_ops.parallel_interleave( - interleave_fn, self.cycle_length, - self.block_length, self.sloppy))) + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) self.iterator = self.dataset.make_initializable_iterator() self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() @@ -380,7 +391,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): for i in range(4, 7): self.write_coordination_events[i].set() - def _testSingleThreaded(self, sloppy=False): + def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is single-threaded. No synchronization required. with self.test_session() as sess: @@ -391,7 +402,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 1, self.block_length: 1, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: prefetch_input_elements, }) for expected_element in self._interleave( @@ -408,6 +421,41 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testSingleThreadedSloppy(self): self._testSingleThreaded(sloppy=True) + def testSingleThreadedPrefetch1Itr(self): + self._testSingleThreaded(prefetch_input_elements=1) + + def testSingleThreadedPrefetch1ItrSloppy(self): + self._testSingleThreaded(prefetch_input_elements=1, sloppy=True) + + def testSingleThreadedRagged(self): + # Tests a sequence with wildly different elements per iterator. + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [3, 7, 4], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, + }) + + # Add coordination values for 3 and 7 + self.read_coordination_events[3] = threading.Semaphore(0) + self.write_coordination_events[3] = threading.Event() + self.read_coordination_events[7] = threading.Semaphore(0) + self.write_coordination_events[7] = threading.Event() + + for expected_element in self._interleave( + [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1): + self.write_coordination_events[expected_element].set() + output = sess.run(self.next_element) + self.assertEqual(expected_element * expected_element, output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + def _testTwoThreadsNoContention(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior @@ -420,7 +468,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 2, self.block_length: 1, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -463,6 +513,8 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.cycle_length: 2, self.block_length: 1, self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -472,7 +524,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() - time.sleep(0.1) # Sleep to consistently "avoid" the race condition. + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. actual_element = sess.run(self.next_element) if not done_first_event: done_first_event = True @@ -502,7 +554,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 2, self.block_length: 2, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -545,7 +599,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 2, self.block_length: 2, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -555,7 +611,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() - time.sleep(0.1) # Sleep to consistently "avoid" the race condition. + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. actual_element = sess.run(self.next_element) if not done_first_event: done_first_event = True @@ -583,7 +639,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [], self.cycle_length: 2, self.block_length: 3, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) @@ -604,7 +662,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [0, 0, 0], self.cycle_length: 2, self.block_length: 3, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) @@ -615,7 +675,8 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testNonEmptyInputIntoEmptyOutputsSloppy(self): self._testNonEmptyInputIntoEmptyOutputs(sloppy=True) - def _testPartiallyEmptyOutputs(self, sloppy=False): + def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1): + race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds # Mixture of non-empty and empty interleaved datasets. with self.test_session() as sess: self._clear_coordination_events() @@ -627,27 +688,31 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.cycle_length: 2, self.block_length: 1, self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: prefetch_input_elements, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): self.write_coordination_events[expected_element].set() - if done_first_event: # First event starts the worker threads + # First event starts the worker threads. Additionally, when running the + # sloppy case with prefetch_input_elements=0, we get stuck if we wait + # for the read coordination event for certain event orderings in the + # presence of finishing iterators. + if done_first_event and not (sloppy and (i in race_indices)): self.read_coordination_events[expected_element].acquire() actual_element = sess.run(self.next_element) - if not done_first_event: + if not done_first_event or (sloppy and (i in race_indices)): done_first_event = True self.read_coordination_events[expected_element].acquire() self.assertEqual(expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.next_element) def testPartiallyEmptyOutputs(self): self._testPartiallyEmptyOutputs() def testPartiallyEmptyOutputsSloppy(self): - self._testPartiallyEmptyOutputs(sloppy=True) + self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0) def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid @@ -661,6 +726,8 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.cycle_length: 2, self.block_length: 1, self.sloppy: True, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) mis_ordering = [ @@ -683,8 +750,10 @@ class ParallelInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 3, - self.sloppy: True + self.block_length: 1, + self.sloppy: True, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) # Test against a generating sequence that differs from the uncontended # case, in order to prove sloppy correctness. @@ -692,7 +761,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self._interleave( [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, cycle_length=2, - block_length=2)): + block_length=3)): self.write_coordination_events[expected_element].set() if done_first_event: # First event starts the worker threads. self.read_coordination_events[expected_element].acquire() @@ -716,7 +785,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 3, self.block_length: 2, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) for i in range(4, 7): self.write_coordination_events[i].set() @@ -790,6 +861,139 @@ class ParallelInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testErrorsInOutputFn(self): + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + + except_on_element_indices = set([3]) + + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + if i in except_on_element_indices: + self.error = ValueError() + self.write_coordination_events[expected_element].set() + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + self.write_coordination_events[expected_element].set() + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testErrorsInInputFn(self): + + def map_py_fn(x): + if x == 5: + raise ValueError() + return x + + def map_fn(x): + return script_ops.py_func(map_py_fn, [x], x.dtype) + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(x) + return dataset + + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) + + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): + if expected_element == 5: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testErrorsInInterleaveFn(self): + + def map_py_fn(x): + if x == 5: + raise ValueError() + return x + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + y = script_ops.py_func(map_py_fn, [x], x.dtype) + dataset = dataset.repeat(y) + return dataset + + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) + + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): + if expected_element == 5: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index e9a07da84a8c80c09ebd4dab0b1d69febe1c9790..69252612a8e6cb29c513003188946be21f3432c2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -24,8 +24,9 @@ import threading import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -52,8 +53,10 @@ class MapDatasetTest(test.TestCase): def _buildMapDataset(self, components, count): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count)) + + return ( + contrib_dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(count)) def testMapDataset(self): """Test an dataset that maps a TF function across its input elements.""" @@ -113,7 +116,8 @@ class MapDatasetTest(test.TestCase): output_buffer_size): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - return (dataset_ops.Dataset.from_tensor_slices(components).map( + + return (contrib_dataset_ops.Dataset.from_tensor_slices(components).map( _map_fn, num_threads=num_threads, output_buffer_size=output_buffer_size) .repeat(count)) @@ -210,9 +214,9 @@ class MapDatasetTest(test.TestCase): def testParallelMapUnspecifiedOutputSize(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message"), - num_threads=2)) + dataset = ( + contrib_dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message"), num_threads=2)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -225,9 +229,11 @@ class MapDatasetTest(test.TestCase): def testParallelMapError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message"), - num_threads=2, output_buffer_size=2)) + dataset = ( + contrib_dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message"), + num_threads=2, + output_buffer_size=2)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -246,9 +252,9 @@ class MapDatasetTest(test.TestCase): def testPrefetchError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message")) - .prefetch(2)) + dataset = ( + contrib_dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.check_numerics(x, "message")).prefetch(2)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -267,9 +273,10 @@ class MapDatasetTest(test.TestCase): def testMapIgnoreError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message")).apply( - error_ops.ignore_errors())) + dataset = ( + contrib_dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -284,10 +291,11 @@ class MapDatasetTest(test.TestCase): def testParallelMapIgnoreError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.check_numerics(x, "message"), - num_threads=2, - output_buffer_size=2).apply(error_ops.ignore_errors())) + dataset = ( + contrib_dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message"), + num_threads=2, + output_buffer_size=2).apply(error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -308,9 +316,10 @@ class MapDatasetTest(test.TestCase): for filename in filenames: write_string_to_file(filename, filename) - dataset = (dataset_ops.Dataset.from_tensor_slices(filenames).map( - io_ops.read_file, num_threads=2, output_buffer_size=2).apply( - error_ops.ignore_errors())) + dataset = ( + contrib_dataset_ops.Dataset.from_tensor_slices(filenames).map( + io_ops.read_file, num_threads=2, output_buffer_size=2).apply( + error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -344,7 +353,7 @@ class MapDatasetTest(test.TestCase): table = lookup_ops.HashTable( lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - input_sentences = dataset_ops.Dataset.from_tensor_slices( + input_sentences = contrib_dataset_ops.Dataset.from_tensor_slices( ["brain brain tank salad surgery", "surgery brain"]) iterator = (input_sentences @@ -368,8 +377,9 @@ class MapDatasetTest(test.TestCase): queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[]) enqueue_op = queue.enqueue_many(elements) close_op = queue.close() - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1) - .map(lambda _: queue.dequeue()).make_initializable_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.from_tensors(0).repeat(-1) + .map(lambda _: queue.dequeue()).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -392,9 +402,10 @@ class MapDatasetTest(test.TestCase): enqueue_op = queue.enqueue_many(elements) close_op = queue.close() - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1) - .map(lambda _: (queue.dequeue(), queue_2.dequeue())) - .make_initializable_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.from_tensors(0).repeat(-1) + .map(lambda _: (queue.dequeue(), queue_2.dequeue())) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -411,9 +422,9 @@ class MapDatasetTest(test.TestCase): def testCaptureVariable(self): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10) - .map(lambda _: counter_var.assign_add(1)) - .make_initializable_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.from_tensors(0).repeat(10) + .map(lambda _: counter_var.assign_add(1)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -431,20 +442,22 @@ class MapDatasetTest(test.TestCase): def testCaptureUninitializedVariableError(self): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10) - .map(lambda _: counter_var.assign_add(1)) - .make_initializable_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.from_tensors(0).repeat(10) + .map(lambda _: counter_var.assign_add(1)).make_initializable_iterator()) init_op = iterator.initializer + get_next = iterator.get_next() with self.test_session() as sess: - with self.assertRaisesRegexp(errors.FailedPreconditionError, - "Failed to capture resource"): - sess.run(init_op) + sess.run(init_op) + with self.assertRaises(errors.NotFoundError): + sess.run(get_next) def testSeededStatefulOperatorIsProperlyStateful(self): - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10) - .map(lambda _: random_ops.random_uniform((), seed=11)).batch(2) - .make_initializable_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.from_tensors(0).repeat(10) + .map(lambda _: random_ops.random_uniform((), seed=11)).batch(2) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -466,7 +479,7 @@ class MapDatasetTest(test.TestCase): self.assertAllClose(random_values, random_values_2) def testMapDict(self): - iterator = (dataset_ops.Dataset.range(10) + iterator = (contrib_dataset_ops.Dataset.range(10) .map(lambda x: {"foo": x * 2, "bar": x ** 2}) .map(lambda d: d["foo"] + d["bar"]) .make_initializable_iterator()) @@ -482,9 +495,9 @@ class MapDatasetTest(test.TestCase): def testMapNamedtuple(self, count=10): # construct dataset of tuples - labels = dataset_ops.Dataset.range(count) + labels = contrib_dataset_ops.Dataset.range(count) images = labels.map(lambda l: -l) - dataset_tuple = dataset_ops.Dataset.zip((labels, images)) + dataset_tuple = contrib_dataset_ops.Dataset.zip((labels, images)) # convert dataset of tuples to dataset of namedtuples example = namedtuple("Example", ["label", "image"]) @@ -517,7 +530,7 @@ class MapDatasetTest(test.TestCase): def testUseStepContainerInMap(self): row = np.arange(6) iterator = ( - dataset_ops.Dataset.from_tensors(row) + contrib_dataset_ops.Dataset.from_tensors(row) .map(lambda elems: functional_ops.map_fn(lambda x: x * x, elems)) .make_initializable_iterator()) init_op = iterator.initializer @@ -547,10 +560,8 @@ class MapDatasetTest(test.TestCase): buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) iterator = ( - dataset_ops.Dataset.range(100) - .map(_map_fn) - .prefetch(buffer_size_placeholder) - .make_initializable_iterator()) + contrib_dataset_ops.Dataset.range(100).map(_map_fn) + .prefetch(buffer_size_placeholder).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -586,9 +597,10 @@ class MapDatasetTest(test.TestCase): sess.run(get_next) def testReturnList(self): - iterator = (dataset_ops.Dataset.range(10) - .map(lambda x: [x, constant_op.constant(37.0)]) - .make_initializable_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.range(10) + .map(lambda x: [x, constant_op.constant(37.0)]) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -607,9 +619,9 @@ class MapDatasetTest(test.TestCase): return script_ops.py_func( _map_py_func, [x_tensor], [dtypes.int64, dtypes.float64]) - iterator = (dataset_ops.Dataset.range(10) - .map(_map_fn) - .make_initializable_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.range(10).map(_map_fn) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -633,9 +645,9 @@ class MapDatasetTest(test.TestCase): values=(i * np.array([1])), dense_shape=np.array([1, 1])) - iterator = (dataset_ops.Dataset.range(10) - .map(_sparse) - .make_initializable_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.range(10).map(_sparse) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -661,7 +673,7 @@ class MapDatasetTest(test.TestCase): return sparse_ops.sparse_concat(0, [i, i]) iterator = ( - dataset_ops.Dataset.range(10).map(_sparse).map(_check) + contrib_dataset_ops.Dataset.range(10).map(_sparse).map(_check) .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -683,23 +695,26 @@ class MapDatasetTest(test.TestCase): get_next = iterator.get_next() return x * get_next - return dataset_ops.Dataset.range(10).map(_map_fn) + return contrib_dataset_ops.Dataset.range(10).map(_map_fn) def _build_graph(): - captured_iterator = dataset_ops.Dataset.range( + captured_iterator = contrib_dataset_ops.Dataset.range( 10).make_initializable_iterator() ds = _build_ds(captured_iterator) iterator = ds.make_initializable_iterator() init_op = iterator.initializer - return captured_iterator.initializer, init_op + get_next = iterator.get_next() + return captured_iterator.initializer, init_op, get_next with ops.Graph().as_default() as g: - captured_init_op, init_op = _build_graph() + captured_init_op, init_op, get_next = _build_graph() with self.test_session(graph=g) as sess: sess.run(captured_init_op) - with self.assertRaises(errors.UnimplementedError): - # CapturedFunction does not support capturing IteratorResource. - sess.run(init_op) + sess.run(init_op) + for i in range(10): + self.assertEquals(i * i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) class MapDatasetSerializationTest( @@ -718,8 +733,9 @@ class MapDatasetSerializationTest( def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(self._num_epochs)) + return ( + contrib_dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) def testSaveRestoreCore(self): self.run_core_tests( @@ -735,7 +751,98 @@ class MapDatasetSerializationTest( return random_ops.random_uniform( (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - return dataset_ops.Dataset.range(100).map(_map_fn) + return contrib_dataset_ops.Dataset.range(100).map(_map_fn) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + +class ParallelMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 1 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) + + def _build_ds_with_prefetch(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) + + def testSaveRestoreCore(self): + for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: + self.run_core_tests( + ds_fn, + lambda: ds_fn(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return contrib_dataset_ops.Dataset.range(100).map(_map_fn) self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) @@ -744,7 +851,7 @@ class MapDatasetSerializationTest( def _build_ds(): counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) - return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + return (contrib_dataset_ops.Dataset.from_tensors(0).repeat(10).map( lambda _: counter_var.assign_add(1))) self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) @@ -758,7 +865,7 @@ class MapDatasetSerializationTest( def defun_fn(x): return constant_op.constant(1000) + math_ops.to_int32(x) - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) self.run_core_tests(_build_ds, None, num_outputs) @@ -776,7 +883,7 @@ class MapDatasetSerializationTest( return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return contrib_dataset_ops.Dataset.range(num_outputs).map(defun_fn) self.run_core_tests(_build_ds, None, num_outputs) @@ -785,7 +892,7 @@ class IgnoreErrorsSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_ds(self, components): - return dataset_ops.Dataset.from_tensor_slices(components).map( + return contrib_dataset_ops.Dataset.from_tensor_slices(components).map( lambda x: array_ops.check_numerics(x, "message")).apply( error_ops.ignore_errors()) diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 8e6ad061a11752ab7b1ffc13c90b4fa52f67d6aa..a431670829ed1d66f1719985af73eafa1fe45982 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -19,11 +19,10 @@ from __future__ import print_function import os +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import counter from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops -from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -34,20 +33,11 @@ from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib class RangeDatasetTest(test.TestCase): - def tearDown(self): - # Remove all checkpoint files. - prefix = self._iterator_checkpoint_prefix() - pattern = prefix + "*" - files = gfile.Glob(pattern) - map(gfile.Remove, files) - def testStop(self): stop = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator() @@ -216,20 +206,25 @@ class RangeDatasetTest(test.TestCase): self.assertEqual(-1, sess.run(negative_get_next)) self.assertEqual(-2, sess.run(negative_get_next)) - def _iterator_checkpoint_prefix(self): + +class RangeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _iterator_checkpoint_prefix_local(self): return os.path.join(self.get_temp_dir(), "iterator") def _save_op(self, iterator_resource): iterator_state_variant = gen_dataset_ops.serialize_iterator( iterator_resource) save_op = io_ops.write_file( - self._iterator_checkpoint_prefix(), + self._iterator_checkpoint_prefix_local(), parsing_ops.serialize_tensor(iterator_state_variant)) return save_op def _restore_op(self, iterator_resource): iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) + io_ops.read_file(self._iterator_checkpoint_prefix_local()), + dtypes.variant) restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, iterator_state_variant) return restore_op @@ -283,382 +278,16 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testSaveRestoreUsingSaverFromMetaGraph(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) - # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection - # so that it can be automatically picked up by the Saver. - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) - saver = saver_lib.Saver() - return init_op, get_next, saver - - start = 2 - stop = 10 - break_point = 5 - path = self._iterator_checkpoint_prefix() - meta_filename = path + ".meta" - - # Execute input pipeline for a few steps and save iterator state. - with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - saver.save(sess, path) - - # Build the saver from the MetaGraph using import_meta_graph and - # check that the iterator state is restored. - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - init_op, get_next = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreUsingBuiltSaver(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection - # so that it can be automatically picked up by the Saver. - saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) - saver = saver_lib.Saver() - return init_op, get_next, saver - - start = 2 - stop = 10 - stop_new = 15 - break_point = 5 - path = self._iterator_checkpoint_prefix() - - # Execute input pipeline for a few steps and save iterator state. - with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - saver.save(sess, path) - - # Manually build a modified Graph and Saver instead of importing - # MetaGraph and verify that original iterator state gets restored. - with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop_new) - with self.test_session(graph=g) as sess: - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreUsingSaverThenInit(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection - # so that it can be automatically picked up by the Saver. - saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) - saver = saver_lib.Saver() - return init_op, get_next, saver + def _build_range_dataset(self, start, stop): + return dataset_ops.Dataset.range(start, stop) - start = 2 - stop = 10 - stop_new = 15 - break_point = 5 - path = self._iterator_checkpoint_prefix() - - # Execute input pipeline for a few steps and save iterator state. - with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - saver.save(sess, path) - - # Restore iterator state call and then call init_op for the iterator and - # verify that the new iterator hides the restored iterator. - with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop_new) - with self.test_session(graph=g) as sess: - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - sess.run(init_op) - for i in range(start, stop_new): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreWithoutBuildingDatasetGraph(self): - - def _build_graph(start, stop, num_epochs): - dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - num_epochs = 5 - break_point = 5 - break_epoch = 3 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for _ in range(break_epoch): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - # Create an empty IteratorResource and restore the Iterator into it. - output_types = dtypes.int64 - output_shapes = tensor_shape.scalar() - iterator = iterator_ops.Iterator.from_structure(output_types, - output_shapes) - restore_op = self._restore_op(iterator._iterator_resource) - get_next = iterator.get_next() - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - for _ in range(break_epoch + 1, num_epochs): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreInModifiedGraph(self): - - def _build_graph(start, stop): - dataset = dataset_ops.Dataset.range(start, stop) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. + def testRangeCore(self): start = 2 stop = 10 stop_1 = 8 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - # Intentionally build a graph with a different value for stop to make sure - # the original dataset graph is actually getting loaded. - init_op, get_next, _, restore_op = _build_graph(start, stop_1) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testInitThenRestore(self): - # Note: Calling init_op before restore_op is redundant. This test just makes - # sure we do not fail if restore is called on an already initialized - # iterator resource. - - def _build_graph(start, stop): - dataset = dataset_ops.Dataset.range(start, stop) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testMultipleSaves(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - start = 2 - stop = 10 - break_point1 = 5 - break_point2 = 7 - - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point1): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point1, break_point2): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - break_point2 = 7 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point2, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreWithRepeat(self): - - def _build_graph(start, stop, num_epochs): - iterator = dataset_ops.Dataset.range( - start, stop).repeat(num_epochs).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - start = 2 - stop = 10 - num_epochs = 5 - break_range = 5 - break_epoch = 3 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph( - start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for _ in range(break_epoch - 1): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - for i in range(start, break_range): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_range, stop): - self.assertEqual(i, sess.run(get_next)) - for _ in range(break_epoch, num_epochs): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreExhaustedIterator(self): - - def _build_graph(start, stop, num_epochs): - iterator = dataset_ops.Dataset.range( - start, stop).repeat(num_epochs).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - start = 2 - stop = 10 - num_epochs = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph( - start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for _ in range(num_epochs): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + self.run_core_tests(lambda: self._build_range_dataset(start, stop), + lambda: self._build_range_dataset(start, stop_1), + stop - start) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index 5338ec56bf275e481a984964e39aa0c1ade3a752..e0494736b72ae52f586cb80d42a5c1e50ac17a61 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -21,6 +21,7 @@ import itertools import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -124,5 +125,18 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) +class ScanDatasetSerialzationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_elements): + return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) + + def testScanCore(self): + num_output = 5 + self.run_core_tests(lambda: self._build_dataset(num_output), + lambda: self._build_dataset(2), num_output) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 6b5b53cc0f8f2d1df5622a5bc5e2f8ef04c6342a..45943d56ecb4bc18a6221157d0eeeae4efdf23cc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -18,12 +18,12 @@ from __future__ import division from __future__ import print_function import collections -import os import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -31,9 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.platform import gfile from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib class ShuffleDatasetTest(test.TestCase): @@ -157,321 +155,135 @@ class ShuffleDatasetTest(test.TestCase): self.assertEqual(10, counts[i]) -class ShuffleDatasetSerializationTest(test.TestCase): +class ShuffleDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - def tearDown(self): - # Remove all checkpoint files. - prefix = self._ckpt_path() - pattern = prefix + "*" - files = gfile.Glob(pattern) - map(gfile.Remove, files) - - def _build_graph(self, - range_limit=10, - num_repeats=5, - buffer_size=5, - seed=None, - reshuffle_each_iteration=None, - build_saveable=True): - iterator = dataset_ops.Dataset.range(range_limit).shuffle( + def _build_shuffle_dataset( + self, + range_limit=10, + num_repeats=5, + buffer_size=5, + seed=None, + reshuffle_each_iteration=None, + ): + return dataset_ops.Dataset.range(range_limit).shuffle( buffer_size, seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration).repeat( - num_repeats).make_initializable_iterator() - if build_saveable: - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - - def _ckpt_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) - - def _save(self, sess, saver): - saver.save(sess, self._ckpt_path()) + reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) - def _restore(self, saver, sess): - saver.restore(sess, self._latest_ckpt()) + def testShuffleCore(self): - def _import_meta_graph(self): - meta_file_path = self._ckpt_path() + ".meta" - return saver_lib.import_meta_graph(meta_file_path) - - def _testReadWithBreaks(self, break_points, init_before_restore=False): seed = 55 range_limit = 10 num_repeats = 5 num_outputs = range_limit * num_repeats buffer_sizes = [1, 3, 8, 10, 25, 50] reshuffle_each_iteration = False + # pylint: disable=cell-var-from-loop + # pylint: disable=g-long-lambda for buffer_size in buffer_sizes: - expected = [] - actual = [] - # Generate the ground truth. - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, _ = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_outputs): - expected.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - # Run and checkpoint after first break_point. - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_points[0]): - actual.append(sess.run(get_next_op)) - self._save(sess, saver) - - # Load from checkpoint and continue running while stopping at each - # subsequent checkpoint. - for i in range(len(break_points)): - with ops.Graph().as_default() as g: - saver = self._import_meta_graph() - init_op, get_next_op = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - if init_before_restore: - sess.run(init_op) - self._restore(saver, sess) - start = break_points[i] - end = break_points[ - i + 1] if i < len(break_points) - 1 else num_outputs - for _ in range(end - start): - actual.append(sess.run(get_next_op)) - self._save(sess, saver) - if end == num_outputs: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - self.assertEqual(expected, actual) - - def testSaveRestore(self): - self._testReadWithBreaks([8]) # rng buffer_size: 0 - self._testReadWithBreaks([13]) # rng buffer_size: 1 - self._testReadWithBreaks([18]) # rng buffer_size: 2 - self._testReadWithBreaks([23]) # rng buffer_size: 3 - - def testSaveUnusedIterator(self): - self._testReadWithBreaks([0]) - - def testSaveFullyUsedIterator(self): - self._testReadWithBreaks([50]) - - def testMultipleBreaks(self): - self._testReadWithBreaks([0, 5, 9, 15, 25, 32]) - - def testIdempotence(self): - # Attempt to save iterator immediately after restoring. - self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32]) - - def testInitThenRestore(self): - self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True) - - def testRestoreExhaustedIterator(self): - seed = 55 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False - for buffer_size in buffer_sizes: - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_outputs): - sess.run(get_next_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - self._save(sess, saver) - - with ops.Graph().as_default() as g: - saver = self._import_meta_graph() - init_op, get_next_op = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testResetRestoredIterator(self): - seed = 55 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False - for buffer_size in buffer_sizes: - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_outputs // 2): - sess.run(get_next_op) - self._save(sess, saver) - - outputs = [] - with ops.Graph().as_default() as g: - saver = self._import_meta_graph() - init_op, get_next_op = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - sess.run(init_op) - for _ in range(num_outputs): - outputs.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - expected_outputs_sorted = sorted( - np.array([range(range_limit) - for _ in range(num_repeats)]).flatten()) - self.assertEqual(expected_outputs_sorted, sorted(outputs)) - - def testRestoreInModifiedGraph(self): - seed = 55 - break_point = 25 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [3, 8, 10, 25, 50] - reshuffle_each_iteration = False - for buffer_size in buffer_sizes: - expected = [] - actual_without_restore = [] - actual = [] - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_point): - expected.append(sess.run(get_next_op)) - actual.extend(expected) - self._save(sess, saver) - for _ in range(num_outputs - break_point): - expected.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - g.seed = 20 # Different seed than previous graph for shuffle rngs. - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_outputs): - actual_without_restore.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - g.seed = 20 # Different seed than previous graph for shuffle rngs. - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - for _ in range(num_outputs - break_point): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - # Since the modified graph has a different random seed it produces a - # different order of examples. - self.assertNotEqual(expected, actual_without_restore) - self.assertEqual(sorted(expected), sorted(actual_without_restore)) - self.assertEqual(expected, actual) - - def testDoNotBuildSaveable(self): - seed = 55 - break_point = 25 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [3, 8, 10, 25, 50] - reshuffle_each_iteration = False - for buffer_size in buffer_sizes: - actual = [] - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_point): - sess.run(get_next_op) - self._save(sess, saver) - - with ops.Graph().as_default() as g: - g.seed = 20 # Different seed than previous graph for shuffle rngs. - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration, - build_saveable=False) - with self.test_session(graph=g) as sess: - # Since the SaveableObject was not added to Saver's list - # of saveables, iterator state is not restored by saver.restore(). - self._restore(saver, sess) - with self.assertRaises(errors.FailedPreconditionError): - sess.run(get_next_op) - sess.run(init_op) - for _ in range(num_outputs): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - expected_outputs_sorted = sorted( - np.array([range(range_limit) for _ in range(num_repeats)]).flatten()) - self.assertEqual(expected_outputs_sorted, sorted(actual)) + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + +class ShuffleAndRepeatTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed, count=5, num_elements=20): + return dataset_ops.Dataset.range(num_elements).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + + def testCorrectOutput(self): + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertSequenceEqual( + sorted(output), sorted( + np.array([range(20) for _ in range(5)]).flatten())) + for i in range(5): + self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) + + def testReshuffling(self): + # Check that the output orders of different epochs are indeed different. + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + for i in range(4): + epoch1 = output[i * 20:(i + 1) * 20] + epoch2 = output[(i + 1) * 20:(i + 2) * 20] + self.assertNotEqual(epoch1, epoch2) + + def testSameOrderForSameSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertEqual(output1, output2) + + def testDifferentOrderForDifferentSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountNone(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountMinusOne(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testInfiniteOutputs(self): + # Asserting the iterator is exhausted after producing 100 items should fail. + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + + def testInfiniteEmpty(self): + with self.assertRaises(errors.OutOfRangeError): + self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), + [], 100) + with self.assertRaises(errors.OutOfRangeError): + self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [], + 100) + + def testLargeBufferSize(self): + with ops.Graph().as_default() as g: + ds = dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=21)) + get_next_op = ds.make_one_shot_iterator().get_next() + with self.test_session(graph=g) as sess: + sess.run(get_next_op) + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 8f24d6b2f612cff662aa8a36085bc69a9ea1a290..07bdf920446e953c2a1abaf495d2e9e1256106fd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops @@ -209,5 +210,48 @@ class StatsDatasetTest(test.TestCase): sess.run(stats_aggregator_1.subscribe(iterator)) +class StatsDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset_bytes_stats(self, num_elements): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + + def testBytesStatsDatasetSaveableCore(self): + num_outputs = 100 + self.run_core_tests( + lambda: self._build_dataset_bytes_stats(num_outputs), + lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) + + def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag)) + + def _build_dataset_multiple_tags(self, + num_elements, + tag1="record_latency", + tag2="record_latency_2"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + + def testLatencyStatsDatasetSaveableCore(self): + num_outputs = 100 + + self.run_core_tests( + lambda: self._build_dataset_latency_stats(num_outputs), + lambda: self._build_dataset_latency_stats(num_outputs // 10), + num_outputs) + + self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), + None, num_outputs) + + tag1 = "record_latency" + tag2 = "record_latency" + self.run_core_tests( + lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), + None, num_outputs) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..55296d5710e7f66408bb7464cf790149d6df9fa1 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class UniqueDatasetTest(test.TestCase): + + def _testSimpleHelper(self, dtype, test_cases): + """Test the `unique()` transformation on a list of test cases. + + Args: + dtype: The `dtype` of the elements in each test case. + test_cases: A list of pairs of lists. The first component is the test + input that will be passed to the transformation; the second component + is the expected sequence of outputs from the transformation. + """ + + # The `current_test_case` will be updated when we loop over `test_cases` + # below; declare it here so that the generator can capture it once. + current_test_case = [] + dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case, + dtype).apply(unique.unique()) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for test_case, expected in test_cases: + current_test_case = test_case + sess.run(iterator.initializer) + for element in expected: + if dtype == dtypes.string: + element = compat.as_bytes(element) + self.assertAllEqual(element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testSimpleInt(self): + for dtype in [dtypes.int32, dtypes.int64]: + self._testSimpleHelper(dtype, [ + ([], []), + ([1], [1]), + ([1, 1, 1, 1, 1, 1, 1], [1]), + ([1, 2, 3, 4], [1, 2, 3, 4]), + ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]), + ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]), + ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]), + ]) + + def testSimpleString(self): + self._testSimpleHelper(dtypes.string, [ + ([], []), + (["hello"], ["hello"]), + (["hello", "hello", "hello"], ["hello"]), + (["hello", "world"], ["hello", "world"]), + (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]), + ]) + + +class UniqueSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testUnique(self): + + def build_dataset(num_elements, unique_elem_range): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: x % unique_elem_range).apply(unique.unique()) + + self.run_core_tests(lambda: build_dataset(200, 100), + lambda: build_dataset(40, 100), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 25ed58cdf5833cd041582046bc1a358625e321e0..4349085a10135b4dee842a29916aeb5febe9ddd4 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -40,6 +40,25 @@ py_library( ], ) +py_library( + name = "random_ops", + srcs = [ + "random_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "readers", srcs = [ @@ -62,6 +81,19 @@ py_library( ], ) +py_library( + name = "shuffle_ops", + srcs = [ + "shuffle_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":random_ops", + ":transformation_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_library( name = "transformation_ops", srcs = [ @@ -73,6 +105,7 @@ py_library( "resampling.py", "scan_ops.py", "stats_ops.py", + "unique.py", ], srcs_version = "PY2AND3", deps = [ @@ -89,6 +122,7 @@ py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", "//third_party/py/numpy", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 63782d229e1535892686f202ca1f0833dee6ed80..76c07b2c999e1424e8efe4af515fddee73922c9c 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -22,6 +22,7 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -231,32 +232,29 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): input_dataset.output_types) self._input_dataset = input_dataset self._batch_size = batch_size - # pylint: disable=protected-access - self._row_shape = dataset_ops._partial_shape_to_tensor(row_shape) - # pylint: enable=protected-access + self._row_shape = row_shape def _as_variant_tensor(self): return gen_dataset_ops.dense_to_sparse_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._batch_size, - self._row_shape, - output_shapes=self.output_shapes, - output_types=self.output_types) + row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) @property def output_classes(self): - return (ops.Tensor, ops.Tensor, ops.Tensor) + return sparse_tensor.SparseTensor @property def output_shapes(self): - num_elements = tensor_shape.Dimension(None) - return (tensor_shape.matrix(num_elements, self._row_shape.shape[0] + 1), - tensor_shape.vector(num_elements), - tensor_shape.vector(self._row_shape.shape[0] + 1)) + return tensor_shape.vector(None).concatenate(self._row_shape) @property def output_types(self): - return (dtypes.int64, self._input_dataset.output_types, dtypes.int64) + return self._input_dataset.output_types class _RestructuredDataset(dataset_ops.Dataset): @@ -390,17 +388,12 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset - and then combines them into a batch. Similarly to `batch_and_drop_remainder`, - if the batch size does not evenly divide the input dataset size, this - transformation will drop the final smaller element. - - - Functionally, it is equivalent to `map` followed by - `batch_and_drop_remainder`. However, by fusing the two transformations - together, the implementation can be more efficient. This transformation is a - stop gap solution for performance critical workloads. Once automatic input - pipeline optimization are implemented, the fusing of map and batch will not - need to be exposed at the API level and this method will be removed. + and then combines them into a batch. Functionally, it is equivalent to `map` + followed by `batch`. However, by fusing the two transformations together, the + implementation can be more efficient. Surfacing this transformation in the API + is temporary. Once automatic input pipeline optimization is implemented, + the fusing of `map` and `batch` will happen automatically and this API will be + deprecated. Args: map_func: A function mapping a nested structure of tensors to another @@ -414,7 +407,7 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 626a9e0edcea5928b1636c1a2a86e83657c966a5..fafd231061a9108b2585f4fc9256b6f069b7c37a 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -364,7 +364,7 @@ class Dataset(dataset_ops.Dataset): When reading a single input file, you can skip elements as follows: ```python - d = tf.contrib.data.TFRecordDataset(FLAGS.input_file) + d = tf.data.TFRecordDataset(FLAGS.input_file) d = d.shard(FLAGS.num_workers, FLAGS.worker_index) d = d.repeat(FLAGS.num_epochs) d = d.shuffle(FLAGS.shuffle_buffer_size) @@ -382,12 +382,11 @@ class Dataset(dataset_ops.Dataset): sharding strategy within a complete pipeline: ```python - d = Dataset.list_files(FLAGS.pattern) + d = tf.data.Dataset.list_files(FLAGS.pattern) d = d.shard(FLAGS.num_workers, FLAGS.worker_index) d = d.repeat(FLAGS.num_epochs) d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.repeat() - d = d.interleave(tf.contrib.data.TFRecordDataset, + d = d.interleave(tf.data.TFRecordDataset, cycle_length=FLAGS.num_readers, block_length=1) d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) ``` @@ -549,7 +548,7 @@ class Dataset(dataset_ops.Dataset): elements are produced. `cycle_length` controls the number of input elements that are processed concurrently. If you set `cycle_length` to 1, this transformation will handle one input element at a time, and will produce - identical results = to @{tf.contrib.data.Dataset.flat_map}. In general, + identical results = to @{tf.data.Dataset.flat_map}. In general, this transformation will apply `map_func` to `cycle_length` input elements, open iterators on the returned `Dataset` objects, and cycle through them producing `block_length` consecutive elements from each iterator, and diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 53324e06e7f1dc249388410f0e14e42336630cd1..3124ca1d1540e12d949dded88ce1c66181be3595 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -31,7 +32,7 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): """A `Dataset` that maps a function over its input and flattens the result.""" def __init__(self, input_dataset, map_func, cycle_length, block_length, - sloppy): + sloppy, buffer_output_elements, prefetch_input_elements): """See `tf.contrib.data.parallel_interleave()` for details.""" super(ParallelInterleaveDataset, self).__init__() self._input_dataset = input_dataset @@ -74,6 +75,14 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): block_length, dtype=dtypes.int64, name="block_length") self._sloppy = ops.convert_to_tensor( sloppy, dtype=dtypes.bool, name="sloppy") + self._buffer_output_elements = convert.optional_param_to_tensor( + "buffer_output_elements", + buffer_output_elements, + argument_default=2 * block_length) + self._prefetch_input_elements = convert.optional_param_to_tensor( + "prefetch_input_elements", + prefetch_input_elements, + argument_default=2 * cycle_length) def _as_variant_tensor(self): return gen_dataset_ops.parallel_interleave_dataset( @@ -82,6 +91,8 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): self._cycle_length, self._block_length, self._sloppy, + self._buffer_output_elements, + self._prefetch_input_elements, f=self._map_func, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -101,7 +112,12 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): return self._output_types -def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): +def parallel_interleave(map_func, + cycle_length, + block_length=1, + sloppy=False, + buffer_output_elements=None, + prefetch_input_elements=None): """A parallel version of the `Dataset.interleave()` transformation. `parallel_interleave()` maps `map_func` across its input to produce nested @@ -129,12 +145,17 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): Args: map_func: A function mapping a nested structure of tensors to a `Dataset`. - cycle_length: The number of threads to interleave from in parallel. - block_length: The number of consecutive elements to pull from a thread - before advancing to the next thread. + cycle_length: The number of input `Dataset`s to interleave from in parallel. + block_length: The number of consecutive elements to pull from an input + `Dataset` before advancing to the next input `Dataset`. sloppy: If false, elements are produced in deterministic order. Otherwise, the implementation is allowed, for the sake of expediency, to produce elements in a non-deterministic order. + buffer_output_elements: The number of elements each iterator being + interleaved should buffer (similar to the `.prefetch()` transformation for + each interleaved iterator). + prefetch_input_elements: The number of input elements to transform to + iterators before they are needed for interleaving. Returns: A `Dataset` transformation function, which can be passed to @@ -142,7 +163,9 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): """ def _apply_fn(dataset): return ParallelInterleaveDataset( - dataset, map_func, cycle_length, block_length, sloppy) + dataset, map_func, cycle_length, block_length, sloppy, + buffer_output_elements, prefetch_input_elements) + return _apply_fn @@ -187,11 +210,11 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): map_func: A function mapping a nested structure of tensors (having shapes and types defined by `self.output_shapes` and `self.output_types`) to a `Dataset`. - cycle_length: The number of threads to interleave from in parallel. - block_length: The number of consecutive elements to pull from a thread - before advancing to the next thread. Note: sloppy_interleave will - skip the remainder of elements in the block_length in order to avoid - blocking. + cycle_length: The number of input `Dataset`s to interleave from in parallel. + block_length: The number of consecutive elements to pull from an input + `Dataset` before advancing to the next input `Dataset`. Note: + `sloppy_interleave` will skip the remainder of elements in the + `block_length` in order to avoid blocking. Returns: A `Dataset` transformation function, which can be passed to @@ -199,5 +222,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): """ def _apply_fn(dataset): return ParallelInterleaveDataset( - dataset, map_func, cycle_length, block_length, sloppy=True) + dataset, + map_func, + cycle_length, + block_length, + sloppy=True, + buffer_output_elements=None, + prefetch_input_elements=None) + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7d727165feabb101549567f28a2dfa07083de244 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -0,0 +1,67 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Datasets for random number generators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class RandomDataset(dataset_ops.Dataset): + """A `Dataset` of pseudorandom values.""" + + def __init__(self, seed=None): + """A `Dataset` of pseudorandom values.""" + super(RandomDataset, self).__init__() + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + self._seed2 = ops.convert_to_tensor( + seed2, dtype=dtypes.int64, name="seed2") + + def _as_variant_tensor(self): + return gen_dataset_ops.random_dataset( + seed=self._seed, + seed2=self._seed2, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.int64 diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index acb7a43211482f9cdeed66542abab5dbde78d60e..347e5edc7b0d479dfa260e8cec500ffaaba375be 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -179,6 +179,7 @@ def read_batch_features(file_pattern, dataset = dataset.shuffle(capacity) dataset = dataset.batch(batch_size) dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) + dataset = dataset.prefetch(1) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() return outputs diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 2744786e9eec4c9268ba854df6ea761339bb0b4e..1c88366273f5d186509454188e02350d4ea9f66b 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -188,7 +188,7 @@ def scan(initial_state, scan_func): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): return _ScanDataset(dataset, initial_state, scan_func) diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..99bb79bc06a421f811869ca9169aaa11deaca2f3 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -0,0 +1,120 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental shuffle ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import gen_dataset_ops + + +class _ShuffleAndRepeatDataset(dataset_ops.Dataset): + """A `Dataset` that fuses `shuffle` and `repeat`.""" + + def __init__(self, + input_dataset, + buffer_size, + count=None, + seed=None): + """See `Dataset.map()` for details.""" + super(_ShuffleAndRepeatDataset, self).__init__() + self._input_dataset = input_dataset + self._buffer_size = ops.convert_to_tensor( + buffer_size, dtype=dtypes.int64, name="buffer_size") + if count is None: + self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") + else: + self._count = ops.convert_to_tensor( + count, dtype=dtypes.int64, name="count") + + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + self._seed2 = ops.convert_to_tensor( + seed2, dtype=dtypes.int64, name="seed2") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + input_resource = self._input_dataset._as_variant_tensor() + return gen_dataset_ops.shuffle_and_repeat_dataset( + input_resource, + buffer_size=self._buffer_size, + count=self._count, + seed=self._seed, + seed2=self._seed2, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + # pylint: enable=protected-access + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +def shuffle_and_repeat(buffer_size, count=None, seed=None): + """Shuffles and repeats a Dataset returning a new permutation for each epoch. + + `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))` + + is equivalent to + + `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)` + + The difference is that the latter dataset is not serializable. So, + if you need to checkpoint an input pipeline with reshuffling you must use + this implementation. + + Args: + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + maximum number elements that will be buffered when prefetching. + count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + number of times the dataset should be repeated. The default behavior + (if `count` is `None` or `-1`) is for the dataset be repeated + indefinitely. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): # pylint: disable=missing-docstring + return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index b8875bd533ddc9e2c195646619dccf3aab5225e4..1dd0729513c0d46db25226178eb17b41efaae0ae 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -117,7 +117,7 @@ def bytes_produced_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): @@ -139,7 +139,7 @@ def latency_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py new file mode 100644 index 0000000000000000000000000000000000000000..133e17d20d0fc4c8d52cef3c95c132374e927a0b --- /dev/null +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -0,0 +1,82 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unique element dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import gen_dataset_ops + + +def unique(): + """Creates a `Dataset` from another `Dataset`, discarding duplicates. + + Use this transformation to produce a dataset that contains one instance of + each unique element in the input. For example: + + ```python + dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1]) + + # Using `unique()` will drop the duplicate elements. + dataset = dataset.apply(tf.contrib.data.unique()) # ==> { 1, 37, 2 } + ``` + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return UniqueDataset(dataset) + + return _apply_fn + + +class UniqueDataset(dataset_ops.Dataset): + """A `Dataset` contains the unique elements from its input.""" + + def __init__(self, input_dataset): + """See `unique()` for details.""" + super(UniqueDataset, self).__init__() + self._input_dataset = input_dataset + if input_dataset.output_types not in (dtypes.int32, dtypes.int64, + dtypes.string): + raise TypeError( + "`tf.contrib.data.unique()` only supports inputs with a single " + "`tf.int32`, `tf.int64`, or `tf.string` component.") + + def _as_variant_tensor(self): + return gen_dataset_ops.unique_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index 87c80740a8f0c0721394b5d832bc96e548e3a313..f6de5998d73a4869d2444cd90c9b64d1a2c889ac 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -7,7 +7,11 @@ exports_files([ "generic_tree_model_proto.swig", ]) -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", + "tf_pyclif_proto_library", +) filegroup( name = "all_files", @@ -34,3 +38,10 @@ tf_proto_library( protodeps = [":generic_tree_model"], visibility = ["//visibility:public"], ) + +tf_pyclif_proto_library( + name = "generic_tree_model_pyclif", + proto_lib = ":generic_tree_model", + proto_srcfile = "generic_tree_model.proto", + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig index d3d201afd5761e7c5c136301c779222bedc68492..cafb9314caee1c4907786b8101e7c71bd7095306 100644 --- a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig +++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig @@ -2,7 +2,7 @@ %include "net/proto/swig/protofunc.swig" -#ifndef MUST_USE_RESULT +#ifndef ABSL_MUST_USE_RESULT #error Use this file only as a %include or %import after google.swig. #endif diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index b2c641f8ab3ea23c5135042e4b1223d487ae8cbc..95848af69950bdaa680c41daecd8cbd8f3174f8e 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -60,6 +60,7 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:spectral_ops", "//tensorflow/python:state_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", @@ -437,6 +438,7 @@ cuda_py_test( "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:spectral_ops_test_util", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:platform_test", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 66827179e9fa1bea852f55246c263c4696cf3bdc..7b401e178f35fe56e4eb461936565f5c630ec4cf 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -159,6 +159,10 @@ _allowed_symbols = [ 'assign_log_moving_mean_exp', 'moving_mean_variance', 'estimator_head_distribution_regression', + 'quadrature_scheme_softmaxnormal_gauss_hermite', + 'quadrature_scheme_softmaxnormal_quantiles', + 'quadrature_scheme_lognormal_gauss_hermite', + 'quadrature_scheme_lognormal_quantiles', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index 25a9b6f5fe2ed6d218d6b44650fce17fa89c0664..dcfb0eb05185d36d96947905c2eb91b2201aece1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -22,9 +22,9 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import _gen_mask from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import masked_autoregressive_default_template from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import _gen_mask from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables @@ -149,5 +149,17 @@ class MaskedAutoregressiveFlowShiftOnlyTest(MaskedAutoregressiveFlowTest): } +class MaskedAutoregressiveFlowUnrollLoopTest(MaskedAutoregressiveFlowTest): + + @property + def _autoregressive_flow_kwargs(self): + return { + "shift_and_log_scale_fn": masked_autoregressive_default_template( + hidden_layers=[2], shift_only=False), + "is_constant_jacobian": False, + "unroll_loop": True, + } + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 38b3a23c2d684a6f89b7c4be4a763c649bf4de15..49451446b56d290f130c5db90c13b94974d92dc9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -28,8 +28,19 @@ from tensorflow.python.ops.distributions.bijector_test_util import assert_biject from tensorflow.python.platform import test -class ReshapeBijectorTest(test.TestCase): - """Tests correctness of the reshape transformation.""" +class _ReshapeBijectorTest(object): + """Base class for testing the reshape transformation. + + Methods defined in this class call a method self.build_shapes() that + is implemented by subclasses defined below, returning respectively + ReshapeBijectorTestStatic: static shapes, + ReshapeBijectorTestDynamic: shape placeholders of known ndims, and + ReshapeBijectorTestDynamicNdims: shape placeholders of unspecified ndims, + so that each test in this base class is automatically run over all + three cases. The subclasses also implement assertRaisesError to test + for either Python exceptions (in the case of static shapes) or + TensorFlow op errors (dynamic shapes). + """ def setUp(self): self._rng = np.random.RandomState(42) @@ -40,9 +51,10 @@ class ReshapeBijectorTest(test.TestCase): expected_y = np.reshape(expected_x, [4, 6]) with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,]) bijector = Reshape( - event_shape_out=[6,], - event_shape_in=[3, 2], + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) (x_, y_, @@ -52,66 +64,23 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.forward_log_det_jacobian(expected_x), bijector.inverse_log_det_jacobian(expected_y), - )) + ), feed_dict=feed_dict) self.assertEqual("reshape", bijector.name) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(0., fldj_, rtol=1e-6, atol=0) self.assertAllClose(0., ildj_, rtol=1e-6, atol=0) - def testEventShapeDynamicNdims(self): - """Check forward/inverse shape methods with dynamic ndims.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_ph = array_ops.placeholder(dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_ph = array_ops.placeholder(dtype=dtypes.int32) - - bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, validate_args=True) - - # using the _tensor methods, we should always get a fully-specified - # result since these are evaluated at graph runtime. - with self.test_session() as sess: - (shape_out_, - shape_in_) = sess.run(( - bijector.forward_event_shape_tensor(shape_in), - bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeDynamic(self): - """Check shape methods with static ndims but dynamic shape.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_partial = tensor_shape.TensorShape([None,]) - shape_in_ph = array_ops.placeholder( - shape=[1,], dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_partial = tensor_shape.TensorShape([None, None]) - shape_out_ph = array_ops.placeholder( - shape=[2,], dtype=dtypes.int32) + def testEventShapeTensor(self): + """Test event_shape_tensor methods when even ndims may be dynamic.""" + shape_in_static = [2, 3] + shape_out_static = [6,] + shape_in, shape_out, feed_dict = self.build_shapes(shape_in_static, + shape_out_static) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, - validate_args=True) - - # if event shapes are not statically available, should - # return partially-specified TensorShapes. - self.assertAllEqual( - bijector.forward_event_shape(shape_in).as_list(), - shape_out_partial.as_list()) - self.assertAllEqual( - bijector.inverse_event_shape(shape_out).as_list(), - shape_in_partial.as_list()) + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) # using the _tensor methods, we should always get a fully-specified # result since these are evaluated at graph runtime. @@ -120,42 +89,9 @@ class ReshapeBijectorTest(test.TestCase): shape_in_) = sess.run(( bijector.forward_event_shape_tensor(shape_in), bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeStatic(self): - """Check shape methods when shape is statically known.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_out = tensor_shape.TensorShape([2, 3]) - - bijector_static = Reshape( - event_shape_out=shape_out, - event_shape_in=shape_in, - validate_args=True) - - # test that forward_ and inverse_event_shape do sensible things - # when shapes are statically known. - self.assertEqual( - bijector_static.forward_event_shape(shape_in), - shape_out) - self.assertEqual( - bijector_static.inverse_event_shape(shape_out), - shape_in) - - with self.test_session() as sess: - (shape_out_static_, - shape_in_static_, - ) = sess.run(( - bijector_static.forward_event_shape_tensor(shape_in), - bijector_static.inverse_event_shape_tensor(shape_out), - )) - self.assertAllEqual(shape_out, shape_out_static_) - self.assertAllEqual(shape_in, shape_in_static_) + ), feed_dict=feed_dict) + self.assertAllEqual(shape_out_static, shape_out_) + self.assertAllEqual(shape_in_static, shape_in_) def testScalarReshape(self): """Test reshaping to and from a scalar shape ().""" @@ -166,11 +102,11 @@ class ReshapeBijectorTest(test.TestCase): expected_x_scalar = np.random.randn(1,) expected_y_scalar = expected_x_scalar[0] + shape_in, shape_out, feed_dict = self.build_shapes([], [1,]) with self.test_session() as sess: bijector = Reshape( - event_shape_out=[], - event_shape_in=[1,], validate_args=True) - + event_shape_out=shape_in, + event_shape_in=shape_out, validate_args=True) (x_, y_, x_scalar_, @@ -180,53 +116,178 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.inverse(expected_y_scalar), bijector.forward(expected_x_scalar), - )) + ), feed_dict=feed_dict) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0) self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0) - def testRaisesOpError(self): - x1 = np.random.randn(4, 2, 3) - x2 = np.random.randn(4, 3, 2) - x3 = np.random.randn(4, 5, 1, 1) + def testMultipleUnspecifiedDimensionsOpError(self): with self.test_session() as sess: - shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32) - shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32) + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,]) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) - with self.assertRaisesOpError( + with self.assertRaisesError( + "elements must have at most one `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testInvalidDimensionsOpError(self): + + with self.test_session() as sess: + + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "elements must be either positive integers or `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testValidButNonMatchingInputOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + # Here we pass in a tensor (x) whose shape is compatible with + # the output shape, so tf.reshape will throw no error, but + # doesn't match the expected input shape. + with self.assertRaisesError( "Input `event_shape` does not match `event_shape_in`."): - sess.run(bijector.forward(x2), - feed_dict={shape_out_ph: [1, 6, 1], - shape_in_ph: [2, 3]}) + sess.run(bijector.forward(x), + feed_dict=feed_dict) - with self.assertRaisesOpError( - "event_shape_out entries must be positive."): - sess.run(bijector.forward(x1), - feed_dict={shape_out_ph: [-1, -1, 6], - shape_in_ph: [2, 3]}) + def testValidButNonMatchingInputPartiallySpecifiedOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "Input `event_shape` does not match `event_shape_in`."): + sess.run(bijector.forward(x), + feed_dict=feed_dict) + + def testInputOutputMismatchOpError(self): + x1 = np.random.randn(4, 2, 3) + x2 = np.random.randn(4, 1, 1, 5) + + with self.test_session() as sess: + shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3], + [1, 1, 5]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) # test that *all* methods check basic assertions - fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]} - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): + with self.assertRaisesError( + "Input to reshape is a tensor with"): sess.run(bijector.forward(x1), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse(x3), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse_log_det_jacobian(x3), - feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.forward_log_det_jacobian(x1), - feed_dict=fd_mismatched) + with self.assertRaisesError( + "Input to reshape is a tensor with"): + sess.run(bijector.inverse(x2), feed_dict=fd_mismatched) + + def testOneShapePartiallySpecified(self): + expected_x = np.random.randn(4, 6) + expected_y = np.reshape(expected_x, [4, 2, 3]) + + with self.test_session() as sess: + # one of input/output shapes is partially specified + shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testBothShapesPartiallySpecified(self): + expected_x = np.random.randn(4, 2, 3) + expected_y = np.reshape(expected_x, [4, 3, 2]) + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testDefaultVectorShape(self): + expected_x = np.random.randn(4, 4) + expected_y = np.reshape(expected_x, [4, 2, 2]) + with self.test_session() as sess: + _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2]) + bijector = Reshape(shape_out, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def build_shapes(self, *args, **kwargs): + raise NotImplementedError("Subclass failed to implement `build_shapes`.") + + +class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_static = shape_in + shape_out_static = shape_out + feed_dict = {} + return shape_in_static, shape_out_static, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesRegexp(Exception, msg) + + def testEventShape(self): + shape_in_static = tensor_shape.TensorShape([2, 3]) + shape_out_static = tensor_shape.TensorShape([6,]) + bijector = Reshape( + event_shape_out=shape_out_static, + event_shape_in=shape_in_static, validate_args=True) + + # test that forward_ and inverse_event_shape do sensible things + # when shapes are statically known. + self.assertEqual( + bijector.forward_event_shape(shape_in_static), + shape_out_static) + self.assertEqual( + bijector.inverse_event_shape(shape_out_static), + shape_in_static) def testBijectiveAndFinite(self): x = np.random.randn(4, 2, 3) @@ -238,5 +299,32 @@ class ReshapeBijectorTest(test.TestCase): validate_args=True) assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=(len(shape_in),), + dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=(len(shape_out),), + dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + +class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 2d74aa1f320149d0f7ef9e9c52b8c7053c2f74d7..a255d4fc890e67180532e342332a8e3f63a869cd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -395,5 +395,110 @@ class MixtureStddevTest(test.TestCase): self.assertAllClose(actual_devs, expected_devs) +class _PadTest(object): + + def testNegAxisCorrectness(self): + x_ = np.float32([[1., 2, 3], + [4, 5, 6]]) + value_ = np.float32(0.25) + count_ = np.int32(2) + with self.test_session() as sess: + x = array_ops.placeholder_with_default( + x_, shape=x_.shape if self.is_static_shape else None) + value = (constant_op.constant(value_) if self.is_static_shape + else array_ops.placeholder_with_default(value_, shape=None)) + count = (constant_op.constant(count_) if self.is_static_shape + else array_ops.placeholder_with_default(count_, shape=None)) + + x0_front = distribution_util.pad( + x, axis=-2, value=value, count=count, front=True) + x0_back = distribution_util.pad( + x, axis=-2, count=count, back=True) + x0_both = distribution_util.pad( + x, axis=-2, value=value, front=True, back=True) + + if self.is_static_shape: + self.assertAllEqual([4, 3], x0_front.shape) + self.assertAllEqual([4, 3], x0_back.shape) + self.assertAllEqual([4, 3], x0_both.shape) + + [x0_front_, x0_back_, x0_both_] = sess.run([ + x0_front, x0_back, x0_both]) + + self.assertAllClose( + np.float32([[value_]*3, + [value_]*3, + [1, 2, 3], + [4, 5, 6]]), + x0_front_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[1, 2, 3], + [4, 5, 6], + [0.]*3, + [0.]*3]), + x0_back_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[value_]*3, + [1, 2, 3], + [4, 5, 6], + [value_]*3]), + x0_both_, atol=0., rtol=1e-6) + + def testPosAxisCorrectness(self): + x_ = np.float32([[1., 2, 3], + [4, 5, 6]]) + value_ = np.float32(0.25) + count_ = np.int32(2) + with self.test_session() as sess: + x = array_ops.placeholder_with_default( + x_, shape=x_.shape if self.is_static_shape else None) + value = (constant_op.constant(value_) if self.is_static_shape + else array_ops.placeholder_with_default(value_, shape=None)) + count = (constant_op.constant(count_) if self.is_static_shape + else array_ops.placeholder_with_default(count_, shape=None)) + + x1_front = distribution_util.pad( + x, axis=1, value=value, count=count, front=True) + x1_back = distribution_util.pad( + x, axis=1, count=count, back=True) + x1_both = distribution_util.pad( + x, axis=1, value=value, front=True, back=True) + + if self.is_static_shape: + self.assertAllEqual([2, 5], x1_front.shape) + self.assertAllEqual([2, 5], x1_back.shape) + self.assertAllEqual([2, 5], x1_both.shape) + + [x1_front_, x1_back_, x1_both_] = sess.run([ + x1_front, x1_back, x1_both]) + + self.assertAllClose( + np.float32([[value_]*2 + [1, 2, 3], + [value_]*2 + [4, 5, 6]]), + x1_front_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[1, 2, 3] + [0.]*2, + [4, 5, 6] + [0.]*2]), + x1_back_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[value_, 1, 2, 3, value_], + [value_, 4, 5, 6, value_]]), + x1_both_, atol=0., rtol=1e-6) + + +class PadStaticTest(_PadTest, test.TestCase): + + @property + def is_static_shape(self): + return True + + +class PadDynamicTest(_PadTest, test.TestCase): + + @property + def is_static_shape(self): + return False + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py index a7571806f295af4566e57ac4a785bc8774fd31ab..a4e75660083dc2edd1759a3a54e221d9e8a268c3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import importlib import numpy as np +from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,7 +29,6 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import variables -from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -200,7 +200,7 @@ class HalfNormalTest(test.TestCase): with self.test_session(): scale = np.array([[1.0, 2.0, 3.0]]) halfnorm = hn_lib.HalfNormal(scale=scale) - + # See https://en.wikipedia.org/wiki/Half-normal_distribution for the # entropy formula used here. expected_entropy = 0.5 * np.log(np.pi * scale ** 2.0 / 2.0) + 0.5 diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py index ece6bc077d9e21502fdfd01300a9d3e9f2c9c380..ff6092fc260660b512e8123823c63e98a023af6d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py @@ -45,6 +45,17 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, self.assertEqual([4, 5], x.shape) self.assertEqual([4, 5], log_prob_x.shape) + def testSampleAndLogProbBatch(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[[0.3, 0.7]]), + components_distribution=normal_lib.Normal( + loc=[[-1., 1]], scale=[[0.1, 0.5]])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5, 1], x.shape) + self.assertEqual([4, 5, 1], log_prob_x.shape) + def testSampleAndLogProbShapesBroadcastMix(self): mix_probs = np.float32([.3, .7]) bern_probs = np.float32([[.4, .6], [.25, .75]]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py index 3c0147b8cf6e1b6a2791e85c0c0997992445fa7e..1035cb00f76d95c7c52c3e812e8bb2868d34b890 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py @@ -18,37 +18,40 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.contrib.distributions.python.ops import poisson_lognormal from tensorflow.contrib.distributions.python.ops import test_util -from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class PoissonLogNormalQuadratureCompoundTest( - test_util.DiscreteScalarDistributionTestHelpers, test.TestCase): +class _PoissonLogNormalQuadratureCompoundTest( + test_util.DiscreteScalarDistributionTestHelpers): """Tests the PoissonLogNormalQuadratureCompoundTest distribution.""" def testSampleProbConsistent(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=-2., - scale=1.1, - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + -2., + shape=[] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1.1, + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_log_prob( - sess.run, pln, rtol=0.1) + sess.run, pln, batch_size=1, rtol=0.1) def testMeanVariance(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=0., - scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + 0., + shape=[] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1., + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_mean_variance( sess.run, pln, rtol=0.02) @@ -56,21 +59,27 @@ class PoissonLogNormalQuadratureCompoundTest( def testSampleProbConsistentBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[0., -0.5], - scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [0., -0.5], + shape=[2] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1., + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_log_prob( - sess.run, pln, rtol=0.1, atol=0.01) + sess.run, pln, batch_size=2, rtol=0.1, atol=0.01) def testMeanVarianceBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[0., -0.5], - scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [0., -0.5], + shape=[2] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1., + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_mean_variance( sess.run, pln, rtol=0.1, atol=0.01) @@ -78,38 +87,46 @@ class PoissonLogNormalQuadratureCompoundTest( def testSampleProbConsistentBroadcastBoth(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[[0.], [-0.5]], - scale=[[1., 0.9]], - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [[0.], [-0.5]], + shape=[2, 1] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + [[1., 0.9]], + shape=[1, 2] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_log_prob( - sess.run, pln, rtol=0.1, atol=0.08) + sess.run, pln, batch_size=4, rtol=0.1, atol=0.08) def testMeanVarianceBroadcastBoth(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[[0.], [-0.5]], - scale=[[1., 0.9]], - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [[0.], [-0.5]], + shape=[2, 1] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + [[1., 0.9]], + shape=[1, 2] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_mean_variance( sess.run, pln, rtol=0.1, atol=0.01) - def testSampleProbConsistentDynamicQuadrature(self): - with self.test_session() as sess: - qgrid = array_ops.placeholder(dtype=dtypes.float32) - qprobs = array_ops.placeholder(dtype=dtypes.float32) - g, p = np.polynomial.hermite.hermgauss(deg=10) - pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=-2., - scale=1.1, - quadrature_grid_and_probs=(g, p), - validate_args=True) - self.run_test_sample_consistent_log_prob( - lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}), - pln, rtol=0.1) + +class PoissonLogNormalQuadratureCompoundStaticShapeTest( + _PoissonLogNormalQuadratureCompoundTest, test.TestCase): + + @property + def static_shape(self): + return True + + +class PoissonLogNormalQuadratureCompoundDynamicShapeTest( + _PoissonLogNormalQuadratureCompoundTest, test.TestCase): + + @property + def static_shape(self): + return False if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py index 595d9f5df755d7defa63d385039bafe4f87aa6ec..4186cf129dbf31724c84133734da3f226817c71a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py @@ -23,11 +23,244 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import sample_stats from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import spectral_ops_test_util from tensorflow.python.platform import test rng = np.random.RandomState(0) +class _AutoCorrelationTest(object): + + @property + def use_static_shape(self): + raise NotImplementedError("Subclass failed to implement `use_static_shape`") + + @property + def dtype(self): + raise NotImplementedError("Subclass failed to implement `dtype`.") + + def test_constant_sequence_axis_0_max_lags_none_center_false(self): + x_ = np.array([[0., 0., 0.], + [1., 1., 1.]]).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + input=x_, + shape=x_.shape if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session() as sess: + # Setting normalize = True means we divide by zero. + auto_corr = sample_stats.auto_correlation( + x_ph, axis=1, center=False, normalize=False) + if self.use_static_shape: + self.assertEqual((2, 3), auto_corr.shape) + auto_corr_ = sess.run(auto_corr) + self.assertAllClose( + [[0., 0., 0.], + [1., 1., 1.]], auto_corr_) + + def test_constant_sequence_axis_0_max_lags_none_center_true(self): + x_ = np.array([[0., 0., 0.], + [1., 1., 1.]]).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + input=x_, + shape=x_.shape if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session() as sess: + # Setting normalize = True means we divide by zero. + auto_corr = sample_stats.auto_correlation( + x_ph, axis=1, normalize=False, center=True) + if self.use_static_shape: + self.assertEqual((2, 3), auto_corr.shape) + auto_corr_ = sess.run(auto_corr) + self.assertAllClose( + [[0., 0., 0.], + [0., 0., 0.]], auto_corr_) + + def check_results_versus_brute_force( + self, x, axis, max_lags, center, normalize): + """Compute auto-correlation by brute force, then compare to tf result.""" + # Brute for auto-corr -- avoiding fft and transpositions. + axis_len = x.shape[axis] + if max_lags is None: + max_lags = axis_len - 1 + else: + max_lags = min(axis_len - 1, max_lags) + auto_corr_at_lag = [] + if center: + x -= x.mean(axis=axis, keepdims=True) + for m in range(max_lags + 1): + auto_corr_at_lag.append(( + np.take(x, indices=range(0, axis_len - m), axis=axis) * + np.conj(np.take(x, indices=range(m, axis_len), axis=axis)) + ).mean(axis=axis, keepdims=True)) + rxx = np.concatenate(auto_corr_at_lag, axis=axis) + if normalize: + rxx /= np.take(rxx, [0], axis=axis) + + x_ph = array_ops.placeholder_with_default( + x, shape=x.shape if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + auto_corr = sample_stats.auto_correlation( + x_ph, axis=axis, max_lags=max_lags, center=center, + normalize=normalize) + if self.use_static_shape: + output_shape = list(x.shape) + output_shape[axis] = max_lags + 1 + self.assertAllEqual(output_shape, auto_corr.shape) + self.assertAllClose(rxx, auto_corr.eval(), rtol=1e-5, atol=1e-5) + + def test_axis_n1_center_false_max_lags_none(self): + x = rng.randn(2, 3, 4).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(2, 3, 4).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-1, max_lags=None, center=False, normalize=False) + + def test_axis_n2_center_false_max_lags_none(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-2, max_lags=None, center=False, normalize=False) + + def test_axis_n1_center_false_max_lags_none_normalize_true(self): + x = rng.randn(2, 3, 4).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(2, 3, 4).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-1, max_lags=None, center=False, normalize=True) + + def test_axis_n2_center_false_max_lags_none_normalize_true(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-2, max_lags=None, center=False, normalize=True) + + def test_axis_0_center_true_max_lags_none(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=0, max_lags=None, center=True, normalize=False) + + def test_axis_2_center_true_max_lags_1(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=2, max_lags=1, center=True, normalize=False) + + def test_axis_2_center_true_max_lags_100(self): + # There are less than 100 elements in axis 2, so expect we get back an array + # the same size as x, despite having asked for 100 lags. + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=2, max_lags=100, center=True, normalize=False) + + def test_long_orthonormal_sequence_has_corr_length_0(self): + l = 10000 + x = rng.randn(l).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + x, shape=(l,) if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=l // 2, center=True, normalize=False) + if self.use_static_shape: + self.assertAllEqual((l // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + # OSS CPU FFT has some accuracy issues is not the most accurate. + # So this tolerance is a bit bad. + self.assertAllClose(1., rxx_[0], rtol=0.05) + # The maximal error in the rest of the sequence is not great. + self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) + # The mean error in the rest is ok, actually 0.008 when I tested it. + self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) + + def test_step_function_sequence(self): + # x jumps to new random value every 10 steps. So correlation length = 10. + x = (rng.randint(-10, 10, size=(1000, 1)) + * np.ones((1, 10))).ravel().astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + x, shape=(1000 * 10,) if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=1000 * 10 // 2, center=True, normalize=False) + if self.use_static_shape: + self.assertAllEqual((1000 * 10 // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + rxx_ /= rxx_[0] + # Expect positive correlation for the first 10 lags, then significantly + # smaller negative. + self.assertGreater(rxx_[:10].min(), 0) + self.assertGreater(rxx_[9], 5 * rxx_[10:20].mean()) + # RXX should be decreasing for the first 10 lags. + diff = np.diff(rxx_) + self.assertLess(diff[:10].max(), 0) + + def test_normalization(self): + l = 10000 + x = 3 * rng.randn(l).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + x, shape=(l,) if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=l // 2, center=True, normalize=True) + if self.use_static_shape: + self.assertAllEqual((l // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + # Note that RXX[0] = 1, despite the fact that E[X^2] = 9, and this is + # due to normalize=True. + # OSS CPU FFT has some accuracy issues is not the most accurate. + # So this tolerance is a bit bad. + self.assertAllClose(1., rxx_[0], rtol=0.05) + # The maximal error in the rest of the sequence is not great. + self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) + # The mean error in the rest is ok, actually 0.008 when I tested it. + self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) + + +class AutoCorrelationTestStaticShapeFloat32(test.TestCase, + _AutoCorrelationTest): + + @property + def dtype(self): + return np.float32 + + @property + def use_static_shape(self): + return True + + +class AutoCorrelationTestStaticShapeComplex64(test.TestCase, + _AutoCorrelationTest): + + @property + def dtype(self): + return np.complex64 + + @property + def use_static_shape(self): + return True + + +class AutoCorrelationTestDynamicShapeFloat32(test.TestCase, + _AutoCorrelationTest): + + @property + def dtype(self): + return np.float32 + + @property + def use_static_shape(self): + return False + + class PercentileTestWithLowerInterpolation(test.TestCase): _interpolation = "lower" diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 103d8e186221e879d1734a097114708429f725bd..cbaf74d3f66253ae5727e1ba579e2d49235b748e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -200,6 +200,27 @@ class TransformedDistributionTest(test.TestCase): self.assertAllEqual([2], multi_logit_normal.event_shape) self.assertAllEqual([2], multi_logit_normal.event_shape_tensor().eval()) + def testCastLogDetJacobian(self): + """Test log_prob when Jacobian and log_prob dtypes do not match.""" + + with self.test_session(): + # Create an identity bijector whose jacobians have dtype int32 + int_identity = bs.Inline( + forward_fn=array_ops.identity, + inverse_fn=array_ops.identity, + inverse_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + forward_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + is_constant_jacobian=True) + normal = self._cls()( + distribution=ds.Normal(loc=0., scale=1.), + bijector=int_identity, + validate_args=True) + + y = normal.sample() + normal.log_prob(y).eval() + normal.prob(y).eval() + normal.entropy().eval() + def testEntropy(self): with self.test_session(): shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py index de4a221f7badca8267a81d612a57137c676ff052..d292b04665e34196670ee4f1c1655f805e04e06a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py @@ -21,9 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops import test_util -from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops +from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vdm_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib @@ -37,7 +35,7 @@ class VectorDiffeomixtureTest( def testSampleProbConsistentBroadcastMixNoBatch(self): with self.test_session() as sess: dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], mix_scale=[1.], distribution=normal_lib.Normal(0., 1.), @@ -54,18 +52,19 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.015) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.015) def testSampleProbConsistentBroadcastMixNonStandardBase(self): with self.test_session() as sess: dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], mix_scale=[1.], distribution=normal_lib.Normal(1., 1.5), @@ -82,18 +81,19 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=2., center=1., rtol=0.006) + sess.run, vdm, radius=2., center=1., rtol=0.015) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=4., center=3., rtol=0.009) + sess.run, vdm, radius=4., center=3., rtol=0.01) def testSampleProbConsistentBroadcastMixBatch(self): with self.test_session() as sess: dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], mix_scale=[1.], distribution=normal_lib.Normal(0., 1.), @@ -113,18 +113,19 @@ class VectorDiffeomixtureTest( ]), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.01) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.01) def testMeanCovarianceNoBatch(self): with self.test_session() as sess: dims = 3 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], mix_scale=[10.], distribution=normal_lib.Normal(0., 1.), @@ -141,14 +142,15 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess.run, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.08) def testMeanCovarianceNoBatchUncenteredNonStandardBase(self): with self.test_session() as sess: dims = 3 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], mix_scale=[10.], distribution=normal_lib.Normal(-1., 1.5), @@ -165,6 +167,7 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) self.run_test_sample_consistent_mean_covariance( sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025) @@ -172,7 +175,7 @@ class VectorDiffeomixtureTest( def testMeanCovarianceBatch(self): with self.test_session() as sess: dims = 3 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], mix_scale=[10.], distribution=normal_lib.Normal(0., 1.), @@ -192,18 +195,16 @@ class VectorDiffeomixtureTest( ]), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess.run, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.07) - def testSampleProbConsistentDynamicQuadrature(self): + def testSampleProbConsistentQuadrature(self): with self.test_session() as sess: - qgrid = array_ops.placeholder(dtype=dtypes.float32) - qprobs = array_ops.placeholder(dtype=dtypes.float32) - g, p = np.polynomial.hermite.hermgauss(deg=8) dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( - mix_loc=[[0.], [1.]], + vdm = vdm_lib.VectorDiffeomixture( + mix_loc=[0.], mix_scale=[1.], distribution=normal_lib.Normal(0., 1.), loc=[ @@ -219,15 +220,14 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], - quadrature_grid_and_probs=(g, p), + quadrature_size=3, validate_args=True) # Ball centered at component0's mean. - sess_run_fn = lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}) self.run_test_sample_consistent_log_prob( - sess_run_fn, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.015) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess_run_fn, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.005) # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent, # (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index 6049419818e18c54209f0be95d41fcecf6627b7e..0fe9f6aa78fbe845b99d0668f075b0162ec2a9f7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -18,12 +18,117 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["AbsoluteValue"] +__all__ = [ + "AbsoluteValue", +] -remove_undocumented(__name__, _allowed_symbols) + +class AbsoluteValue(bijector.Bijector): + """Computes `Y = g(X) = Abs(X)`, element-wise. + + This non-injective bijector allows for transformations of scalar distributions + with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. + + * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse + `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. + * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse + (the set inverse is the singleton `{0}`), but "works" in conjunction with + `TransformedDistribution` to produce a left semi-continuous pdf. + * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the + wrong thing, `-y, y`. This is done for efficiency. If + `validate_args == True`, `y < 0` will raise an exception. + + + ```python + tfd = tf.contrib.distributions + + abs = tfd.bijectors.AbsoluteValue() + + abs.forward([-1., 0., 1.]) + ==> [1., 0., 1.] + + abs.inverse(1.) + ==> [-1., 1.] + + # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. + abs.inverse_log_det_jacobian(1.) + ==> [0., 0.] + + # Special case handling of 0. + abs.inverse(0.) + ==> [0., 0.] + + abs.inverse_log_det_jacobian(0.) + ==> [0., 0.] + ``` + + """ + + def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): + """Instantiates the `AbsoluteValue` bijector. + + Args: + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. Currently only zero is + supported. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness, in particular whether inputs to `inverse` and + `inverse_log_det_jacobian` are non-negative. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: If `event_ndims` is not zero. + """ + self._graph_parents = [] + self._name = name + + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0,): + raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) + else: + if validate_args: + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + event_ndims, 0, message="event_ndims was not 0")], + event_ndims) + + with self._name_scope("init"): + super(AbsoluteValue, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return math_ops.abs(x) + + def _inverse(self, y): + if self.validate_args: + y = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + y) + return -y, y + + def _inverse_log_det_jacobian(self, y): + # If event_ndims = 2, + # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), + # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. + batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] + zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + if self.validate_args: + zeros = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + zeros) + return zeros, zeros + + @property + def _is_injective(self): + return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py deleted file mode 100644 index b84502003ab6c0c4ffdda21eea162f441509e1fa..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""AbsoluteValue bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "AbsoluteValue", -] - - -class AbsoluteValue(bijector.Bijector): - """Computes `Y = g(X) = Abs(X)`, element-wise. - - This non-injective bijector allows for transformations of scalar distributions - with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. - - * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse - `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. - * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse - (the set inverse is the singleton `{0}`), but "works" in conjunction with - `TransformedDistribution` to produce a left semi-continuous pdf. - * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the - wrong thing, `-y, y`. This is done for efficiency. If - `validate_args == True`, `y < 0` will raise an exception. - - - ```python - abs = ds.bijectors.AbsoluteValue() - - abs.forward([-1., 0., 1.]) - ==> [1., 0., 1.] - - abs.inverse(1.) - ==> [-1., 1.] - - # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. - abs.inverse_log_det_jacobian(1.) - ==> [0., 0.] - - # Special case handling of 0. - abs.inverse(0.) - ==> [0., 0.] - - abs.inverse_log_det_jacobian(0.) - ==> [0., 0.] - ``` - - """ - - def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): - """Instantiates the `AbsoluteValue` bijector. - - Args: - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. Currently only zero is - supported. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness, in particular whether inputs to `inverse` and - `inverse_log_det_jacobian` are non-negative. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: If `event_ndims` is not zero. - """ - self._graph_parents = [] - self._name = name - - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0,): - raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) - else: - if validate_args: - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - event_ndims, 0, message="event_ndims was not 0")], - event_ndims) - - with self._name_scope("init"): - super(AbsoluteValue, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - return math_ops.abs(x) - - def _inverse(self, y): - if self.validate_args: - y = control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(y, message="Argument y was negative")], - y) - return -y, y - - def _inverse_log_det_jacobian(self, y): - # If event_ndims = 2, - # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), - # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. - batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] - zeros = array_ops.zeros(batch_shape, dtype=y.dtype) - if self.validate_args: - zeros = control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(y, message="Argument y was negative")], - zeros) - return zeros, zeros - - @property - def _is_injective(self): - return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 940cceff04e77cfc2f7caae5a798d135f7601b95..05bb9c2f9bdf35e222c94db3491157893da64ebd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -18,12 +18,386 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.affine_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib import linalg +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Affine"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Affine", +] + + +def _as_tensor(x, name): + """Convenience to convert to `Tensor` or leave as `None`.""" + return None if x is None else ops.convert_to_tensor(x, name=name) + + +class Affine(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale @ X + shift`. + + Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`. + + In TF parlance, the `scale` term is logically equivalent to: + + ```python + scale = ( + scale_identity_multiplier * tf.diag(tf.ones(d)) + + tf.diag(scale_diag) + + scale_tril + + scale_perturb_factor @ diag(scale_perturb_diag) @ + tf.transpose([scale_perturb_factor]) + ) + ``` + + The `scale` term is applied without necessarily materializing constituent + matrices, i.e., the matmul is [matrix-free]( + https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. + + Examples: + + ```python + # Y = X + b = Affine() + + # Y = X + shift + b = Affine(shift=[1., 2, 3]) + + # Y = 2 * I @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_identity_multiplier=2.) + + # Y = tf.diag(d1) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_diag=[-1., 2, 1]) # Implicitly 3x3. + + # Y = (I + v * v.T) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_perturb_factor=[[1., 0], + [0, 1], + [1, 1]]) + + # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_diag=[1., 3, 3], # Implicitly 3x3. + scale_perturb_diag=[2., 1], # Implicitly 2x2. + scale_perturb_factor=[[1., 0], + [0, 1], + [1, 1]]) + + ``` + + """ + + def __init__(self, + shift=None, + scale_identity_multiplier=None, + scale_diag=None, + scale_tril=None, + scale_perturb_factor=None, + scale_perturb_diag=None, + event_ndims=1, + validate_args=False, + name="affine"): + """Instantiates the `Affine` bijector. + + This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, + giving the forward operation: + + ```none + Y = g(X) = scale @ X + shift + ``` + + where the `scale` term is logically equivalent to: + + ```python + scale = ( + scale_identity_multiplier * tf.diag(tf.ones(d)) + + tf.diag(scale_diag) + + scale_tril + + scale_perturb_factor @ diag(scale_perturb_diag) @ + tf.transpose([scale_perturb_factor]) + ) + ``` + + If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are + specified then `scale += IdentityMatrix`. Otherwise specifying a + `scale` argument has the semantics of `scale += Expand(arg)`, i.e., + `scale_diag != None` means `scale += tf.diag(scale_diag)`. + + Args: + shift: Floating-point `Tensor`. If this is set to `None`, no shift is + applied. + scale_identity_multiplier: floating point rank 0 `Tensor` representing a + scaling done to the identity matrix. + When `scale_identity_multiplier = scale_diag = scale_tril = None` then + `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added + to `scale`. + scale_diag: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k], which represents a k x k + diagonal matrix. + When `None` no diagonal term is added to `scale`. + scale_tril: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k + lower triangular matrix. + When `None` no `scale_tril` term is added to `scale`. + The upper triangular elements above the diagonal are ignored. + scale_perturb_factor: Floating-point `Tensor` representing factor matrix + with last two dimensions of shape `(k, r)`. When `None`, no rank-r + update is added to `scale`. + scale_perturb_diag: Floating-point `Tensor` representing the diagonal + matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which + represents an `r x r` diagonal matrix. When `None` low rank updates will + take the form `scale_perturb_factor * scale_perturb_factor.T`. + event_ndims: Scalar `int` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. Must be 0 or 1. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `perturb_diag` is specified but not `perturb_factor`. + TypeError: if `shift` has different `dtype` from `scale` arguments. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + # Ambiguous definition of low rank update. + if scale_perturb_diag is not None and scale_perturb_factor is None: + raise ValueError("When scale_perturb_diag is specified, " + "scale_perturb_factor must be specified.") + + # Special case, only handling a scaled identity matrix. We don't know its + # dimensions, so this is special cased. + # We don't check identity_multiplier, since below we set it to 1. if all + # other scale args are None. + self._is_only_identity_multiplier = (scale_tril is None and + scale_diag is None and + scale_perturb_factor is None) + + with self._name_scope("init", values=[ + shift, scale_identity_multiplier, scale_diag, scale_tril, + scale_perturb_diag, scale_perturb_factor]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0, 1): + raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) + else: + if validate_args: + # Shape tool will catch if event_ndims is negative. + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_less( + event_ndims, 2, message="event_ndims must be 0 or 1")], + event_ndims) + + if event_ndims_const == 0 and not self._is_only_identity_multiplier: + raise ValueError( + "If event_ndims == 0, the only scale argument you can pass is " + "scale_identity_multiplier. All others operate on vectors.") + + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. + dtype = dtypes.float32 + + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + dtype = shift.dtype.base_dtype + self._shift = shift + + # When no args are specified, pretend the scale matrix is the identity + # matrix. + if (self._is_only_identity_multiplier and + scale_identity_multiplier is None): + scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype) + + # self._create_scale_operator returns a LinearOperator in all cases + # except if self._is_only_identity_multiplier; in which case it + # returns a scalar Tensor. + scale = self._create_scale_operator( + identity_multiplier=scale_identity_multiplier, + diag=scale_diag, + tril=scale_tril, + perturb_diag=scale_perturb_diag, + perturb_factor=scale_perturb_factor, + shift=shift, + validate_args=validate_args) + + if scale.dtype is not None: + dtype = scale.dtype.base_dtype + + if scale is not None and not self._is_only_identity_multiplier: + if (shift is not None and + shift.dtype.base_dtype != scale.dtype.base_dtype): + raise TypeError( + "shift.dtype({}) is incompatible with scale.dtype({}).".format( + shift.dtype, scale.dtype)) + + if scale.tensor_rank is not None: + batch_ndims = scale.tensor_rank - 2 + else: + batch_ndims = scale.tensor_rank_tensor() - 2 + else: + # We won't need shape inference when scale is None or when scale is a + # scalar. + batch_ndims = 0 + self._scale = scale + self._shaper = _DistributionShape( + batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(Affine, self).__init__( + event_ndims=event_ndims, + graph_parents=( + [event_ndims] + + [self._scale] if tensor_util.is_tensor(self._scale) + else self._scale.graph_parents + + [self._shift] if self._shift is not None else []), + is_constant_jacobian=True, + dtype=dtype, + validate_args=validate_args, + name=name) + + def _create_scale_operator(self, identity_multiplier, diag, tril, + perturb_diag, perturb_factor, shift, + validate_args): + """Construct `scale` from various components. + + Args: + identity_multiplier: floating point rank 0 `Tensor` representing a scaling + done to the identity matrix. + diag: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k], which represents a k x k + diagonal matrix. + tril: Floating-point `Tensor` representing the diagonal matrix. + `scale_tril` has shape [N1, N2, ... k], which represents a k x k lower + triangular matrix. + perturb_diag: Floating-point `Tensor` representing the diagonal matrix of + the low rank update. + perturb_factor: Floating-point `Tensor` representing factor matrix. + shift: Floating-point `Tensor` representing `shift in `scale @ X + shift`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + + Returns: + scale. In the case of scaling by a constant, scale is a + floating point `Tensor`. Otherwise, scale is a `LinearOperator`. + + Raises: + ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`. + """ + identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier") + diag = _as_tensor(diag, "diag") + tril = _as_tensor(tril, "tril") + perturb_diag = _as_tensor(perturb_diag, "perturb_diag") + perturb_factor = _as_tensor(perturb_factor, "perturb_factor") + + # If possible, use the low rank update to infer the shape of + # the identity matrix, when scale represents a scaled identity matrix + # with a low rank update. + shape_hint = None + if perturb_factor is not None: + shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2) + + if self._is_only_identity_multiplier: + if validate_args: + return control_flow_ops.with_dependencies( + [check_ops.assert_none_equal( + identity_multiplier, + array_ops.zeros([], identity_multiplier.dtype), + ["identity_multiplier should be non-zero."])], + identity_multiplier) + return identity_multiplier + + scale = distribution_util.make_tril_scale( + loc=shift, + scale_tril=tril, + scale_diag=diag, + scale_identity_multiplier=identity_multiplier, + validate_args=validate_args, + assert_positive=False, + shape_hint=shape_hint) + + if perturb_factor is not None: + return linalg.LinearOperatorLowRankUpdate( + scale, + u=perturb_factor, + diag_update=perturb_diag, + is_diag_update_positive=perturb_diag is None, + is_non_singular=True, # Implied by is_positive_definite=True. + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + return scale + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = x + if self._is_only_identity_multiplier: + y *= self._scale + if self.shift is not None: + return y + self.shift + return y + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + y, expand_batch_dim=False) + with ops.control_dependencies(self._maybe_check_scale() if + self.validate_args else []): + y = self.scale.matmul(y) + y = self._shaper.undo_make_batch_of_event_sample_matrices( + y, sample_shape, expand_batch_dim=False) + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = y + if self.shift is not None: + x -= self.shift + if self._is_only_identity_multiplier: + return x / self._scale + + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + x, expand_batch_dim=False) + # Solve fails if the op is singular so we may safely skip this assertion. + x = self.scale.solve(x) + x = self._shaper.undo_make_batch_of_event_sample_matrices( + x, sample_shape, expand_batch_dim=False) + return x + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(y) + + def _forward_log_det_jacobian(self, x): + if self._is_only_identity_multiplier: + # We don't pad in this case and instead let the fldj be applied + # via broadcast. + event_size = distribution_util.pick_vector( + math_ops.equal(self._shaper.event_ndims, 0), + [1], array_ops.shape(x))[-1] + event_size = math_ops.cast(event_size, dtype=self._scale.dtype) + return math_ops.log(math_ops.abs(self._scale)) * event_size + return self.scale.log_abs_determinant() + + def _maybe_check_scale(self): + try: + return [self.scale.assert_non_singular()] + except NotImplementedError: + pass + return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py deleted file mode 100644 index 05bb9c2f9bdf35e222c94db3491157893da64ebd..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py +++ /dev/null @@ -1,403 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Affine bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import linalg -from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Affine", -] - - -def _as_tensor(x, name): - """Convenience to convert to `Tensor` or leave as `None`.""" - return None if x is None else ops.convert_to_tensor(x, name=name) - - -class Affine(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`. - - In TF parlance, the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - The `scale` term is applied without necessarily materializing constituent - matrices, i.e., the matmul is [matrix-free]( - https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. - - Examples: - - ```python - # Y = X - b = Affine() - - # Y = X + shift - b = Affine(shift=[1., 2, 3]) - - # Y = 2 * I @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_identity_multiplier=2.) - - # Y = tf.diag(d1) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[-1., 2, 1]) # Implicitly 3x3. - - # Y = (I + v * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[1., 3, 3], # Implicitly 3x3. - scale_perturb_diag=[2., 1], # Implicitly 2x2. - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - ``` - - """ - - def __init__(self, - shift=None, - scale_identity_multiplier=None, - scale_diag=None, - scale_tril=None, - scale_perturb_factor=None, - scale_perturb_diag=None, - event_ndims=1, - validate_args=False, - name="affine"): - """Instantiates the `Affine` bijector. - - This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, - giving the forward operation: - - ```none - Y = g(X) = scale @ X + shift - ``` - - where the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are - specified then `scale += IdentityMatrix`. Otherwise specifying a - `scale` argument has the semantics of `scale += Expand(arg)`, i.e., - `scale_diag != None` means `scale += tf.diag(scale_diag)`. - - Args: - shift: Floating-point `Tensor`. If this is set to `None`, no shift is - applied. - scale_identity_multiplier: floating point rank 0 `Tensor` representing a - scaling done to the identity matrix. - When `scale_identity_multiplier = scale_diag = scale_tril = None` then - `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added - to `scale`. - scale_diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - When `None` no diagonal term is added to `scale`. - scale_tril: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k - lower triangular matrix. - When `None` no `scale_tril` term is added to `scale`. - The upper triangular elements above the diagonal are ignored. - scale_perturb_factor: Floating-point `Tensor` representing factor matrix - with last two dimensions of shape `(k, r)`. When `None`, no rank-r - update is added to `scale`. - scale_perturb_diag: Floating-point `Tensor` representing the diagonal - matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which - represents an `r x r` diagonal matrix. When `None` low rank updates will - take the form `scale_perturb_factor * scale_perturb_factor.T`. - event_ndims: Scalar `int` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `perturb_diag` is specified but not `perturb_factor`. - TypeError: if `shift` has different `dtype` from `scale` arguments. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - - # Ambiguous definition of low rank update. - if scale_perturb_diag is not None and scale_perturb_factor is None: - raise ValueError("When scale_perturb_diag is specified, " - "scale_perturb_factor must be specified.") - - # Special case, only handling a scaled identity matrix. We don't know its - # dimensions, so this is special cased. - # We don't check identity_multiplier, since below we set it to 1. if all - # other scale args are None. - self._is_only_identity_multiplier = (scale_tril is None and - scale_diag is None and - scale_perturb_factor is None) - - with self._name_scope("init", values=[ - shift, scale_identity_multiplier, scale_diag, scale_tril, - scale_perturb_diag, scale_perturb_factor]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0, 1): - raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - - if event_ndims_const == 0 and not self._is_only_identity_multiplier: - raise ValueError( - "If event_ndims == 0, the only scale argument you can pass is " - "scale_identity_multiplier. All others operate on vectors.") - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - dtype = shift.dtype.base_dtype - self._shift = shift - - # When no args are specified, pretend the scale matrix is the identity - # matrix. - if (self._is_only_identity_multiplier and - scale_identity_multiplier is None): - scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype) - - # self._create_scale_operator returns a LinearOperator in all cases - # except if self._is_only_identity_multiplier; in which case it - # returns a scalar Tensor. - scale = self._create_scale_operator( - identity_multiplier=scale_identity_multiplier, - diag=scale_diag, - tril=scale_tril, - perturb_diag=scale_perturb_diag, - perturb_factor=scale_perturb_factor, - shift=shift, - validate_args=validate_args) - - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - - if scale is not None and not self._is_only_identity_multiplier: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - else: - # We won't need shape inference when scale is None or when scale is a - # scalar. - batch_ndims = 0 - self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(Affine, self).__init__( - event_ndims=event_ndims, - graph_parents=( - [event_ndims] + - [self._scale] if tensor_util.is_tensor(self._scale) - else self._scale.graph_parents + - [self._shift] if self._shift is not None else []), - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - def _create_scale_operator(self, identity_multiplier, diag, tril, - perturb_diag, perturb_factor, shift, - validate_args): - """Construct `scale` from various components. - - Args: - identity_multiplier: floating point rank 0 `Tensor` representing a scaling - done to the identity matrix. - diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - tril: Floating-point `Tensor` representing the diagonal matrix. - `scale_tril` has shape [N1, N2, ... k], which represents a k x k lower - triangular matrix. - perturb_diag: Floating-point `Tensor` representing the diagonal matrix of - the low rank update. - perturb_factor: Floating-point `Tensor` representing factor matrix. - shift: Floating-point `Tensor` representing `shift in `scale @ X + shift`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - - Returns: - scale. In the case of scaling by a constant, scale is a - floating point `Tensor`. Otherwise, scale is a `LinearOperator`. - - Raises: - ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`. - """ - identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier") - diag = _as_tensor(diag, "diag") - tril = _as_tensor(tril, "tril") - perturb_diag = _as_tensor(perturb_diag, "perturb_diag") - perturb_factor = _as_tensor(perturb_factor, "perturb_factor") - - # If possible, use the low rank update to infer the shape of - # the identity matrix, when scale represents a scaled identity matrix - # with a low rank update. - shape_hint = None - if perturb_factor is not None: - shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2) - - if self._is_only_identity_multiplier: - if validate_args: - return control_flow_ops.with_dependencies( - [check_ops.assert_none_equal( - identity_multiplier, - array_ops.zeros([], identity_multiplier.dtype), - ["identity_multiplier should be non-zero."])], - identity_multiplier) - return identity_multiplier - - scale = distribution_util.make_tril_scale( - loc=shift, - scale_tril=tril, - scale_diag=diag, - scale_identity_multiplier=identity_multiplier, - validate_args=validate_args, - assert_positive=False, - shape_hint=shape_hint) - - if perturb_factor is not None: - return linalg.LinearOperatorLowRankUpdate( - scale, - u=perturb_factor, - diag_update=perturb_diag, - is_diag_update_positive=perturb_diag is None, - is_non_singular=True, # Implied by is_positive_definite=True. - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - return scale - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self._is_only_identity_multiplier: - y *= self._scale - if self.shift is not None: - return y + self.shift - return y - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_check_scale() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self._is_only_identity_multiplier: - return x / self._scale - - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. - x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): - if self._is_only_identity_multiplier: - # We don't pad in this case and instead let the fldj be applied - # via broadcast. - event_size = distribution_util.pick_vector( - math_ops.equal(self._shaper.event_ndims, 0), - [1], array_ops.shape(x))[-1] - event_size = math_ops.cast(event_size, dtype=self._scale.dtype) - return math_ops.log(math_ops.abs(self._scale)) * event_size - return self.scale.log_abs_determinant() - - def _maybe_check_scale(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index aca04a89df7c3ee09d5f7cc10f6779e33fa7aa66..89043b1410370074f11f2cfa59b6b6663fa62521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -18,12 +18,214 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linear_operator -_allowed_symbols = ["AffineLinearOperator"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "AffineLinearOperator", +] + + +class AffineLinearOperator(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale @ X + shift`. + + `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. + + If `X` is a scalar then the forward transformation is: `scale * X + shift` + where `*` denotes the scalar product. + + Note: we don't always simply transpose `X` (but write it this way for + brevity). Actually the input `X` undergoes the following transformation + before being premultiplied by `scale`: + + 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., + `new_sample_shape = [1]`. Otherwise do nothing. + 2. The sample shape is flattened to have one dimension, i.e., + `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. + 3. The sample dim is cyclically rotated left by 1, i.e., + `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the + event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch + dimensions. + + (For more details see `shape.make_batch_of_event_sample_matrices`.) + + The result of the above transformation is that `X` can be regarded as a batch + of matrices where each column is a draw from the distribution. After + premultiplying by `scale`, we take the inverse of this procedure. The input + `Y` also undergoes the same transformation before/after premultiplying by + `inv(scale)`. + + Example Use: + + ```python + linalg = tf.linalg + + x = [1., 2, 3] + + shift = [-1., 0., 1] + diag = [1., 2, 3] + scale = linalg.LinearOperatorDiag(diag) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # y = scale @ x + shift + y = affine.forward(x) # [0., 4, 10] + + shift = [2., 3, 1] + tril = [[1., 0, 0], + [2, 1, 0], + [3, 2, 1]] + scale = linalg.LinearOperatorLowerTriangular(tril) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift + y = affine.forward(x) # [3., 7, 11] + ``` + + """ + + def __init__(self, + shift=None, + scale=None, + event_ndims=1, + validate_args=False, + name="affine_linear_operator"): + """Instantiates the `AffineLinearOperator` bijector. + + Args: + shift: Floating-point `Tensor`. + scale: Subclass of `LinearOperator`. Represents the (batch) positive + definite matrix `M` in `R^{k x k}`. + event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. Must be 0 or 1. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `event_ndims` is not 0 or 1. + TypeError: if `scale` is not a `LinearOperator`. + TypeError: if `shift.dtype` does not match `scale.dtype`. + ValueError: if not `scale.is_non_singular`. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + graph_parents = [] + with self._name_scope("init", values=[shift]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + if tensor_util.constant_value(event_ndims) is not None: + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims not in (0, 1): + raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) + else: + if validate_args: + # Shape tool will catch if event_ndims is negative. + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_less( + event_ndims, 2, message="event_ndims must be 0 or 1")], + event_ndims) + graph_parents += [event_ndims] + + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. + dtype = dtypes.float32 + + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + graph_parents += [shift] + dtype = shift.dtype.base_dtype + self._shift = shift + + if scale is not None: + if (shift is not None and + shift.dtype.base_dtype != scale.dtype.base_dtype): + raise TypeError( + "shift.dtype({}) is incompatible with scale.dtype({}).".format( + shift.dtype, scale.dtype)) + if not isinstance(scale, linear_operator.LinearOperator): + raise TypeError("scale is not an instance of tf.LinearOperator") + if validate_args and not scale.is_non_singular: + raise ValueError("Scale matrix must be non-singular.") + graph_parents += scale.graph_parents + if scale.tensor_rank is not None: + batch_ndims = scale.tensor_rank - 2 + else: + batch_ndims = scale.tensor_rank_tensor() - 2 + graph_parents += [batch_ndims] + if scale.dtype is not None: + dtype = scale.dtype.base_dtype + else: + batch_ndims = 0 # We won't need shape inference when scale is None. + self._scale = scale + self._shaper = _DistributionShape( + batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(AffineLinearOperator, self).__init__( + event_ndims=event_ndims, + graph_parents=graph_parents, + is_constant_jacobian=True, + dtype=dtype, + validate_args=validate_args, + name=name) + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = x + if self.scale is not None: + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + y, expand_batch_dim=False) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + y = self.scale.matmul(y) + y = self._shaper.undo_make_batch_of_event_sample_matrices( + y, sample_shape, expand_batch_dim=False) + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = y + if self.shift is not None: + x -= self.shift + if self.scale is not None: + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + x, expand_batch_dim=False) + # Solve fails if the op is singular so we may safely skip this assertion. + x = self.scale.solve(x) + x = self._shaper.undo_make_batch_of_event_sample_matrices( + x, sample_shape, expand_batch_dim=False) + return x + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(y) + + def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument + if self.scale is None: + return constant_op.constant(0, dtype=x.dtype.base_dtype) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + return self.scale.log_abs_determinant() + + def _maybe_collect_assertions(self): + try: + return [self.scale.assert_non_singular()] + except NotImplementedError: + pass + return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py deleted file mode 100644 index 89043b1410370074f11f2cfa59b6b6663fa62521..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""AffineLinearOperator bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.linalg import linear_operator - - -__all__ = [ - "AffineLinearOperator", -] - - -class AffineLinearOperator(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. - - If `X` is a scalar then the forward transformation is: `scale * X + shift` - where `*` denotes the scalar product. - - Note: we don't always simply transpose `X` (but write it this way for - brevity). Actually the input `X` undergoes the following transformation - before being premultiplied by `scale`: - - 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., - `new_sample_shape = [1]`. Otherwise do nothing. - 2. The sample shape is flattened to have one dimension, i.e., - `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. - 3. The sample dim is cyclically rotated left by 1, i.e., - `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the - event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch - dimensions. - - (For more details see `shape.make_batch_of_event_sample_matrices`.) - - The result of the above transformation is that `X` can be regarded as a batch - of matrices where each column is a draw from the distribution. After - premultiplying by `scale`, we take the inverse of this procedure. The input - `Y` also undergoes the same transformation before/after premultiplying by - `inv(scale)`. - - Example Use: - - ```python - linalg = tf.linalg - - x = [1., 2, 3] - - shift = [-1., 0., 1] - diag = [1., 2, 3] - scale = linalg.LinearOperatorDiag(diag) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # y = scale @ x + shift - y = affine.forward(x) # [0., 4, 10] - - shift = [2., 3, 1] - tril = [[1., 0, 0], - [2, 1, 0], - [3, 2, 1]] - scale = linalg.LinearOperatorLowerTriangular(tril) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift - y = affine.forward(x) # [3., 7, 11] - ``` - - """ - - def __init__(self, - shift=None, - scale=None, - event_ndims=1, - validate_args=False, - name="affine_linear_operator"): - """Instantiates the `AffineLinearOperator` bijector. - - Args: - shift: Floating-point `Tensor`. - scale: Subclass of `LinearOperator`. Represents the (batch) positive - definite matrix `M` in `R^{k x k}`. - event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `event_ndims` is not 0 or 1. - TypeError: if `scale` is not a `LinearOperator`. - TypeError: if `shift.dtype` does not match `scale.dtype`. - ValueError: if not `scale.is_non_singular`. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - graph_parents = [] - with self._name_scope("init", values=[shift]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - if tensor_util.constant_value(event_ndims) is not None: - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims not in (0, 1): - raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - graph_parents += [event_ndims] - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - graph_parents += [shift] - dtype = shift.dtype.base_dtype - self._shift = shift - - if scale is not None: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - if not isinstance(scale, linear_operator.LinearOperator): - raise TypeError("scale is not an instance of tf.LinearOperator") - if validate_args and not scale.is_non_singular: - raise ValueError("Scale matrix must be non-singular.") - graph_parents += scale.graph_parents - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - graph_parents += [batch_ndims] - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - else: - batch_ndims = 0 # We won't need shape inference when scale is None. - self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(AffineLinearOperator, self).__init__( - event_ndims=event_ndims, - graph_parents=graph_parents, - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self.scale is not None: - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self.scale is not None: - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. - x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument - if self.scale is None: - return constant_op.constant(0, dtype=x.dtype.base_dtype) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - return self.scale.log_abs_determinant() - - def _maybe_collect_assertions(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 0db10fb75c8483a8209f39370362b05a03d047ca..3ce7c26213034c7345a20faa803c94a1bfa8d579 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -18,12 +18,151 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.chain_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import itertools -_allowed_symbols = ["Chain"] +from tensorflow.python.framework import constant_op +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Chain", +] + + +class Chain(bijector.Bijector): + """Bijector which applies a sequence of bijectors. + + Example Use: + + ```python + chain = Chain([Exp(), Softplus()], name="one_plus_exp") + ``` + + Results in: + + * Forward: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).forward(x) + = exp.forward(softplus.forward(x)) + = tf.exp(tf.log(1. + tf.exp(x))) + = 1. + tf.exp(x) + ``` + + * Inverse: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).inverse(y) + = softplus.inverse(exp.inverse(y)) + = tf.log(tf.exp(tf.log(y)) - 1.) + = tf.log(y - 1.) + ``` + + """ + + def __init__(self, bijectors=None, validate_args=False, name=None): + """Instantiates `Chain` bijector. + + Args: + bijectors: Python `list` of bijector instances. An empty list makes this + bijector equivalent to the `Identity` bijector. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. Default: + E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. + + Raises: + ValueError: if bijectors have different dtypes. + """ + if bijectors is None: + bijectors = () + self._bijectors = bijectors + + for a_bijector in bijectors: + if not a_bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijector ({})".format( + a_bijector.name)) + + dtype = list(set([b.dtype for b in bijectors])) + if len(dtype) > 2: + raise ValueError("incompatible dtypes: %s" % dtype) + elif len(dtype) == 2: + dtype = dtype[1] if dtype[0] is None else dtype[0] + event_ndims = bijectors[0].event_ndims + elif len(dtype) == 1: + dtype = dtype[0] + event_ndims = bijectors[0].event_ndims + else: + dtype = None + event_ndims = None + + super(Chain, self).__init__( + graph_parents=list(itertools.chain.from_iterable( + b.graph_parents for b in bijectors)), + is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), + validate_args=validate_args, + dtype=dtype, + event_ndims=event_ndims, + name=name or ("identity" if not bijectors else + "_of_".join(["chain"] + [b.name for b in bijectors]))) + + @property + def bijectors(self): + return self._bijectors + + def _shape_helper(self, func_name, input_shape, reverse): + new_shape = input_shape + for b in reversed(self.bijectors) if reverse else self.bijectors: + func = getattr(b, func_name, None) + if func is None: + raise ValueError("unable to call %s on bijector %s (%s)" % + (func_name, b.name, func)) + new_shape = func(new_shape) + return new_shape + + def _forward_event_shape(self, input_shape): + return self._shape_helper("forward_event_shape", input_shape, + reverse=True) + + def _forward_event_shape_tensor(self, input_shape): + return self._shape_helper( + "forward_event_shape_tensor", input_shape, reverse=True) + + def _inverse_event_shape(self, output_shape): + return self._shape_helper("inverse_event_shape", output_shape, + reverse=False) + + def _inverse_event_shape_tensor(self, output_shape): + return self._shape_helper("inverse_event_shape_tensor", output_shape, + reverse=False) + + def _inverse(self, y, **kwargs): + for b in self.bijectors: + y = b.inverse(y, **kwargs.get(b.name, {})) + return y + + def _inverse_log_det_jacobian(self, y, **kwargs): + ildj = constant_op.constant(0., dtype=y.dtype, + name="inverse_log_det_jacobian") + for b in self.bijectors: + ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) + y = b.inverse(y, **kwargs.get(b.name, {})) + return ildj + + def _forward(self, x, **kwargs): + for b in reversed(self.bijectors): + x = b.forward(x, **kwargs.get(b.name, {})) + return x + + def _forward_log_det_jacobian(self, x, **kwargs): + fldj = constant_op.constant(0., dtype=x.dtype, + name="forward_log_det_jacobian") + for b in reversed(self.bijectors): + fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) + x = b.forward(x, **kwargs.get(b.name, {})) + return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py deleted file mode 100644 index 3ce7c26213034c7345a20faa803c94a1bfa8d579..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Chain bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import itertools - -from tensorflow.python.framework import constant_op -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Chain", -] - - -class Chain(bijector.Bijector): - """Bijector which applies a sequence of bijectors. - - Example Use: - - ```python - chain = Chain([Exp(), Softplus()], name="one_plus_exp") - ``` - - Results in: - - * Forward: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).forward(x) - = exp.forward(softplus.forward(x)) - = tf.exp(tf.log(1. + tf.exp(x))) - = 1. + tf.exp(x) - ``` - - * Inverse: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).inverse(y) - = softplus.inverse(exp.inverse(y)) - = tf.log(tf.exp(tf.log(y)) - 1.) - = tf.log(y - 1.) - ``` - - """ - - def __init__(self, bijectors=None, validate_args=False, name=None): - """Instantiates `Chain` bijector. - - Args: - bijectors: Python `list` of bijector instances. An empty list makes this - bijector equivalent to the `Identity` bijector. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. Default: - E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. - - Raises: - ValueError: if bijectors have different dtypes. - """ - if bijectors is None: - bijectors = () - self._bijectors = bijectors - - for a_bijector in bijectors: - if not a_bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijector ({})".format( - a_bijector.name)) - - dtype = list(set([b.dtype for b in bijectors])) - if len(dtype) > 2: - raise ValueError("incompatible dtypes: %s" % dtype) - elif len(dtype) == 2: - dtype = dtype[1] if dtype[0] is None else dtype[0] - event_ndims = bijectors[0].event_ndims - elif len(dtype) == 1: - dtype = dtype[0] - event_ndims = bijectors[0].event_ndims - else: - dtype = None - event_ndims = None - - super(Chain, self).__init__( - graph_parents=list(itertools.chain.from_iterable( - b.graph_parents for b in bijectors)), - is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), - validate_args=validate_args, - dtype=dtype, - event_ndims=event_ndims, - name=name or ("identity" if not bijectors else - "_of_".join(["chain"] + [b.name for b in bijectors]))) - - @property - def bijectors(self): - return self._bijectors - - def _shape_helper(self, func_name, input_shape, reverse): - new_shape = input_shape - for b in reversed(self.bijectors) if reverse else self.bijectors: - func = getattr(b, func_name, None) - if func is None: - raise ValueError("unable to call %s on bijector %s (%s)" % - (func_name, b.name, func)) - new_shape = func(new_shape) - return new_shape - - def _forward_event_shape(self, input_shape): - return self._shape_helper("forward_event_shape", input_shape, - reverse=True) - - def _forward_event_shape_tensor(self, input_shape): - return self._shape_helper( - "forward_event_shape_tensor", input_shape, reverse=True) - - def _inverse_event_shape(self, output_shape): - return self._shape_helper("inverse_event_shape", output_shape, - reverse=False) - - def _inverse_event_shape_tensor(self, output_shape): - return self._shape_helper("inverse_event_shape_tensor", output_shape, - reverse=False) - - def _inverse(self, y, **kwargs): - for b in self.bijectors: - y = b.inverse(y, **kwargs.get(b.name, {})) - return y - - def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant(0., dtype=y.dtype, - name="inverse_log_det_jacobian") - for b in self.bijectors: - ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) - y = b.inverse(y, **kwargs.get(b.name, {})) - return ildj - - def _forward(self, x, **kwargs): - for b in reversed(self.bijectors): - x = b.forward(x, **kwargs.get(b.name, {})) - return x - - def _forward_log_det_jacobian(self, x, **kwargs): - fldj = constant_op.constant(0., dtype=x.dtype, - name="forward_log_det_jacobian") - for b in reversed(self.bijectors): - fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) - x = b.forward(x, **kwargs.get(b.name, {})) - return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 4686af8bc42a3232cb3a34f2cfcce8323c5896dd..cbd60f92a60612c6cf791b2c7708a3310c6e2b6b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -18,12 +18,219 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["CholeskyOuterProduct"] +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "CholeskyOuterProduct", +] + + +class CholeskyOuterProduct(bijector.Bijector): + """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. + + `event_ndims` must be 0 or 2, i.e., scalar or matrix. + + Note: the upper-triangular part of X is ignored (whether or not its zero). + + The surjectivity of g as a map from the set of n x n positive-diagonal + lower-triangular matrices to the set of SPD matrices follows immediately from + executing the Cholesky factorization algorithm on an SPD matrix A to produce a + positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. + + To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular + with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then + `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. + Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal + lower-triangular matrix follows from `inv(L_1)` being positive-diagonal + lower-triangular (which follows from the diagonal of a triangular matrix being + its spectrum), and that the product of two positive-diagonal lower-triangular + matrices is another positive-diagonal lower-triangular matrix. + + A simple inductive argument (proceding one column of L_3 at a time) shows + that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- + diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. + + Examples: + + ```python + bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 2], [2, 5]], i.e., x @ x.T + + bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) + # Result: [[1., 0], [2, 1]], i.e., cholesky(y). + ``` + + """ + + def __init__(self, event_ndims=2, validate_args=False, + name="cholesky_outer_product"): + """Instantiates the `CholeskyOuterProduct` bijector. + + Args: + event_ndims: `constant` `int32` scalar `Tensor` indicating the number of + dimensions associated with a particular draw from the distribution. Must + be 0 or 2. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if event_ndims is neither 0 or 2. + """ + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 2]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") + self._static_event_ndims = event_ndims + super(CholeskyOuterProduct, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self._static_event_ndims == 0: + return math_ops.square(x) + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least(x, 2) + shape = array_ops.shape(x) + is_square = check_ops.assert_equal(shape[-2], shape[-1]) + x = control_flow_ops.with_dependencies([is_matrix, is_square], x) + # For safety, explicitly zero-out the upper triangular part. + x = array_ops.matrix_band_part(x, -1, 0) + return math_ops.matmul(x, x, adjoint_b=True) + + def _inverse(self, y): + return (math_ops.sqrt(y) if self._static_event_ndims == 0 + else linalg_ops.cholesky(y)) + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(x=self._inverse(y)) + + def _forward_log_det_jacobian(self, x): + # Let Y be a symmetric, positive definite matrix and write: + # Y = X X.T + # where X is lower-triangular. + # + # Observe that, + # dY[i,j]/dX[a,b] + # = d/dX[a,b] { X[i,:] X[j,:] } + # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } + # + # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is + # symmetric and X is lower-triangular, we need vectors of dimension: + # d = p (p + 1) / 2 + # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., + # k = { i (i + 1) / 2 + j i>=j + # { undef ij thus i,j!=a. + # + # Since the Jacobian is lower-triangular, we need only compute the product + # of diagonal elements: + # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] + # = X[j,j] + I[i=j] X[i,j] + # = 2 X[j,j]. + # Since there is a 2 X[j,j] term for every lower-triangular element of X we + # conclude: + # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. + if self._static_event_ndims == 0: + if self.validate_args: + is_positive = check_ops.assert_positive( + x, message="All elements must be positive.") + x = control_flow_ops.with_dependencies([is_positive], x) + return np.log(2.) + math_ops.log(x) + + diag = array_ops.matrix_diag_part(x) + + # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output + # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the + # output is unchanged. + diag = self._make_columnar(diag) + + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must be a (batch of) matrix.") + shape = array_ops.shape(x) + is_square = check_ops.assert_equal( + shape[-2], shape[-1], + message="Input must be a (batch of) square matrix.") + # Assuming lower-triangular means we only need check diag>0. + is_positive_definite = check_ops.assert_positive( + diag, message="Input must be positive definite.") + x = control_flow_ops.with_dependencies( + [is_matrix, is_square, is_positive_definite], x) + + # Create a vector equal to: [p, p-1, ..., 2, 1]. + if x.get_shape().ndims is None or x.get_shape()[-1].value is None: + p_int = array_ops.shape(x)[-1] + p_float = math_ops.cast(p_int, dtype=x.dtype) + else: + p_int = x.get_shape()[-1].value + p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) + exponents = math_ops.linspace(p_float, 1., p_int) + + sum_weighted_log_diag = array_ops.squeeze( + math_ops.matmul(math_ops.log(diag), + exponents[..., array_ops.newaxis]), + squeeze_dims=-1) + fldj = p_float * np.log(2.) + sum_weighted_log_diag + + return fldj + + def _make_columnar(self, x): + """Ensures non-scalar input has at least one column. + + Example: + If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. + + If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. + + If `x = 1` then the output is unchanged. + + Args: + x: `Tensor`. + + Returns: + columnar_x: `Tensor` with at least two dimensions. + """ + if x.get_shape().ndims is not None: + if x.get_shape().ndims == 1: + x = x[array_ops.newaxis, :] + return x + shape = array_ops.shape(x) + maybe_expanded_shape = array_ops.concat([ + shape[:-1], + distribution_util.pick_vector( + math_ops.equal(array_ops.rank(x), 1), + [1], np.array([], dtype=np.int32)), + shape[-1:], + ], 0) + return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py deleted file mode 100644 index cbd60f92a60612c6cf791b2c7708a3310c6e2b6b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""CholeskyOuterProduct bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "CholeskyOuterProduct", -] - - -class CholeskyOuterProduct(bijector.Bijector): - """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. - - `event_ndims` must be 0 or 2, i.e., scalar or matrix. - - Note: the upper-triangular part of X is ignored (whether or not its zero). - - The surjectivity of g as a map from the set of n x n positive-diagonal - lower-triangular matrices to the set of SPD matrices follows immediately from - executing the Cholesky factorization algorithm on an SPD matrix A to produce a - positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. - - To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular - with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then - `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. - Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal - lower-triangular matrix follows from `inv(L_1)` being positive-diagonal - lower-triangular (which follows from the diagonal of a triangular matrix being - its spectrum), and that the product of two positive-diagonal lower-triangular - matrices is another positive-diagonal lower-triangular matrix. - - A simple inductive argument (proceding one column of L_3 at a time) shows - that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- - diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. - - Examples: - - ```python - bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) - # Result: [[1., 2], [2, 5]], i.e., x @ x.T - - bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) - # Result: [[1., 0], [2, 1]], i.e., cholesky(y). - ``` - - """ - - def __init__(self, event_ndims=2, validate_args=False, - name="cholesky_outer_product"): - """Instantiates the `CholeskyOuterProduct` bijector. - - Args: - event_ndims: `constant` `int32` scalar `Tensor` indicating the number of - dimensions associated with a particular draw from the distribution. Must - be 0 or 2. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if event_ndims is neither 0 or 2. - """ - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 2]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") - self._static_event_ndims = event_ndims - super(CholeskyOuterProduct, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self._static_event_ndims == 0: - return math_ops.square(x) - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least(x, 2) - shape = array_ops.shape(x) - is_square = check_ops.assert_equal(shape[-2], shape[-1]) - x = control_flow_ops.with_dependencies([is_matrix, is_square], x) - # For safety, explicitly zero-out the upper triangular part. - x = array_ops.matrix_band_part(x, -1, 0) - return math_ops.matmul(x, x, adjoint_b=True) - - def _inverse(self, y): - return (math_ops.sqrt(y) if self._static_event_ndims == 0 - else linalg_ops.cholesky(y)) - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(x=self._inverse(y)) - - def _forward_log_det_jacobian(self, x): - # Let Y be a symmetric, positive definite matrix and write: - # Y = X X.T - # where X is lower-triangular. - # - # Observe that, - # dY[i,j]/dX[a,b] - # = d/dX[a,b] { X[i,:] X[j,:] } - # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } - # - # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is - # symmetric and X is lower-triangular, we need vectors of dimension: - # d = p (p + 1) / 2 - # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., - # k = { i (i + 1) / 2 + j i>=j - # { undef ij thus i,j!=a. - # - # Since the Jacobian is lower-triangular, we need only compute the product - # of diagonal elements: - # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] - # = X[j,j] + I[i=j] X[i,j] - # = 2 X[j,j]. - # Since there is a 2 X[j,j] term for every lower-triangular element of X we - # conclude: - # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. - if self._static_event_ndims == 0: - if self.validate_args: - is_positive = check_ops.assert_positive( - x, message="All elements must be positive.") - x = control_flow_ops.with_dependencies([is_positive], x) - return np.log(2.) + math_ops.log(x) - - diag = array_ops.matrix_diag_part(x) - - # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output - # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the - # output is unchanged. - diag = self._make_columnar(diag) - - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least( - x, 2, message="Input must be a (batch of) matrix.") - shape = array_ops.shape(x) - is_square = check_ops.assert_equal( - shape[-2], shape[-1], - message="Input must be a (batch of) square matrix.") - # Assuming lower-triangular means we only need check diag>0. - is_positive_definite = check_ops.assert_positive( - diag, message="Input must be positive definite.") - x = control_flow_ops.with_dependencies( - [is_matrix, is_square, is_positive_definite], x) - - # Create a vector equal to: [p, p-1, ..., 2, 1]. - if x.get_shape().ndims is None or x.get_shape()[-1].value is None: - p_int = array_ops.shape(x)[-1] - p_float = math_ops.cast(p_int, dtype=x.dtype) - else: - p_int = x.get_shape()[-1].value - p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) - exponents = math_ops.linspace(p_float, 1., p_int) - - sum_weighted_log_diag = array_ops.squeeze( - math_ops.matmul(math_ops.log(diag), - exponents[..., array_ops.newaxis]), - squeeze_dims=-1) - fldj = p_float * np.log(2.) + sum_weighted_log_diag - - return fldj - - def _make_columnar(self, x): - """Ensures non-scalar input has at least one column. - - Example: - If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. - - If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. - - If `x = 1` then the output is unchanged. - - Args: - x: `Tensor`. - - Returns: - columnar_x: `Tensor` with at least two dimensions. - """ - if x.get_shape().ndims is not None: - if x.get_shape().ndims == 1: - x = x[array_ops.newaxis, :] - return x - shape = array_ops.shape(x) - maybe_expanded_shape = array_ops.concat([ - shape[:-1], - distribution_util.pick_vector( - math_ops.equal(array_ops.rank(x), 1), - [1], np.array([], dtype=np.int32)), - shape[-1:], - ], 0) - return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py index d254b635d28099a09a2054536f04ffee3a355b2f..ccb1f029277bc07011df7be047a075274f2b3a27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py @@ -18,12 +18,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["ConditionalBijector"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = ["ConditionalBijector"] + + +class ConditionalBijector(bijector.Bijector): + """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward(self, x, name="forward", **condition_kwargs): + return self._call_forward(x, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse(self, y, name="inverse", **condition_kwargs): + return self._call_inverse(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse_log_det_jacobian( + self, y, name="inverse_log_det_jacobian", **condition_kwargs): + return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward_log_det_jacobian( + self, x, name="forward_log_det_jacobian", **condition_kwargs): + return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py deleted file mode 100644 index ccb1f029277bc07011df7be047a075274f2b3a27..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""ConditionalBijector base.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = ["ConditionalBijector"] - - -class ConditionalBijector(bijector.Bijector): - """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward(self, x, name="forward", **condition_kwargs): - return self._call_forward(x, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse(self, y, name="inverse", **condition_kwargs): - return self._call_inverse(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse_log_det_jacobian( - self, y, name="inverse_log_det_jacobian", **condition_kwargs): - return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward_log_det_jacobian( - self, x, name="forward_log_det_jacobian", **condition_kwargs): - return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index 399d713098eb7223601beb9518dc51dd6160ad64..b1ff840d62a73c941a4d67dec73b5c9f4d5353f9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -18,12 +18,49 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.exp_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import power_transform -_allowed_symbols = ["Exp"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Exp", +] + + +class Exp(power_transform.PowerTransform): + """Compute `Y = g(X) = exp(X)`. + + Example Use: + + ```python + # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + exp = Exp(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + exp(x) == exp.forward(x) + log(x) == exp.inverse(x) + ``` + + Note: the exp(.) is applied element-wise but the Jacobian is a reduction + over the event space. + """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="exp"): + """Instantiates the `Exp` bijector. + + Args: + event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + super(Exp, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py deleted file mode 100644 index b1ff840d62a73c941a4d67dec73b5c9f4d5353f9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Exp bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.bijectors import power_transform - - -__all__ = [ - "Exp", -] - - -class Exp(power_transform.PowerTransform): - """Compute `Y = g(X) = exp(X)`. - - Example Use: - - ```python - # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - exp = Exp(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - exp(x) == exp.forward(x) - log(x) == exp.inverse(x) - ``` - - Note: the exp(.) is applied element-wise but the Jacobian is a reduction - over the event space. - """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="exp"): - """Instantiates the `Exp` bijector. - - Args: - event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - super(Exp, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index cf37aa51115ed98ab263bc03bcb297a03432a7ae..67f39785563255be0fe154aca3cbcf01c6a01e73 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -18,12 +18,107 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.gumbel_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Gumbel"] +__all__ = [ + "Gumbel", +] -remove_undocumented(__name__, _allowed_symbols) + +class Gumbel(bijector.Bijector): + """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + + This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): + + ```none + Y ~ Gumbel(loc, scale) + pdf(y; loc, scale) = exp( + -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale + ``` + """ + + def __init__(self, + loc=0., + scale=1., + event_ndims=0, + validate_args=False, + name="gumbel"): + """Instantiates the `Gumbel` bijector. + + Args: + loc: Float-like `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + scale: Positive Float-like `Tensor` that is the same dtype and is + broadcastable with `loc`. + This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[loc, scale]): + self._loc = ops.convert_to_tensor(loc, name="loc") + self._scale = ops.convert_to_tensor(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, message="Argument scale was not positive") + ], self._scale) + + super(Gumbel, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def loc(self): + """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._loc + + @property + def scale(self): + """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._scale + + def _forward(self, x): + z = (x - self.loc) / self.scale + return math_ops.exp(-math_ops.exp(-z)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.loc - self.scale * math_ops.log(-math_ops.log(y)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + event_dims = self._event_dims_tensor(x) + z = (x - self.loc) / self.scale + return math_ops.reduce_sum( + -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, + constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py deleted file mode 100644 index 67f39785563255be0fe154aca3cbcf01c6a01e73..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Gumbel bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "Gumbel", -] - - -class Gumbel(bijector.Bijector): - """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - - This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): - - ```none - Y ~ Gumbel(loc, scale) - pdf(y; loc, scale) = exp( - -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale - ``` - """ - - def __init__(self, - loc=0., - scale=1., - event_ndims=0, - validate_args=False, - name="gumbel"): - """Instantiates the `Gumbel` bijector. - - Args: - loc: Float-like `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - scale: Positive Float-like `Tensor` that is the same dtype and is - broadcastable with `loc`. - This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[loc, scale]): - self._loc = ops.convert_to_tensor(loc, name="loc") - self._scale = ops.convert_to_tensor(scale, name="scale") - check_ops.assert_same_float_dtype([self._loc, self._scale]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, message="Argument scale was not positive") - ], self._scale) - - super(Gumbel, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def loc(self): - """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._loc - - @property - def scale(self): - """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._scale - - def _forward(self, x): - z = (x - self.loc) / self.scale - return math_ops.exp(-math_ops.exp(-z)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.loc - self.scale * math_ops.log(-math_ops.log(y)) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - event_dims = self._event_dims_tensor(x) - z = (x - self.loc) / self.scale - return math_ops.reduce_sum( - -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, - constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index db10c3fc3a9135b4c408ada74622ba9b360f9ec1..fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -18,12 +18,124 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.inline_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Inline"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Inline", +] + + +class Inline(bijector.Bijector): + """Bijector constructed from custom callables. + + Example Use: + + ```python + exp = Inline( + forward_fn=tf.exp, + inverse_fn=tf.log, + inverse_log_det_jacobian_fn=( + lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), + name="exp") + ``` + + The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. + """ + + def __init__(self, + forward_fn=None, + inverse_fn=None, + inverse_log_det_jacobian_fn=None, + forward_log_det_jacobian_fn=None, + forward_event_shape_fn=None, + forward_event_shape_tensor_fn=None, + inverse_event_shape_fn=None, + inverse_event_shape_tensor_fn=None, + is_constant_jacobian=False, + validate_args=False, + name="inline"): + """Creates a `Bijector` from callables. + + Args: + forward_fn: Python callable implementing the forward transformation. + inverse_fn: Python callable implementing the inverse transformation. + inverse_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the inverse transformation. + forward_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the forward transformation. + forward_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + forward_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + is_constant_jacobian: Python `bool` indicating that the Jacobian is + constant for all input arguments. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + super(Inline, self).__init__( + event_ndims=0, + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + self._forward_fn = forward_fn + self._inverse_fn = inverse_fn + self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn + self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn + self._forward_event_shape_fn = forward_event_shape_fn + self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn + self._inverse_event_shape_fn = inverse_event_shape_fn + self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn + + def _forward_event_shape(self, input_shape): + if self._forward_event_shape_fn is None: + # By default assume shape doesn't change. + return input_shape + return self._forward_event_shape_fn(input_shape) + + def _forward_event_shape_tensor(self, input_shape): + if self._forward_event_shape_tensor_fn is None: + # By default assume shape doesn't change. + return input_shape + return self._forward_event_shape_tensor_fn(input_shape) + + def _inverse_event_shape(self, output_shape): + if self._inverse_event_shape_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_fn(output_shape) + + def _inverse_event_shape_tensor(self, output_shape): + if self._inverse_event_shape_tensor_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_tensor_fn(output_shape) + + def _forward(self, x, **kwargs): + if not callable(self._forward_fn): + raise NotImplementedError( + "forward_fn is not a callable function.") + return self._forward_fn(x, **kwargs) + + def _inverse(self, y, **kwargs): + if not callable(self._inverse_fn): + raise NotImplementedError( + "inverse_fn is not a callable function.") + return self._inverse_fn(y, **kwargs) + + def _inverse_log_det_jacobian(self, y, **kwargs): + if not callable(self._inverse_log_det_jacobian_fn): + raise NotImplementedError( + "inverse_log_det_jacobian_fn is not a callable function.") + return self._inverse_log_det_jacobian_fn(y, **kwargs) + + def _forward_log_det_jacobian(self, y, **kwargs): + if not callable(self._forward_log_det_jacobian_fn): + raise NotImplementedError( + "forward_log_det_jacobian_fn is not a callable function.") + return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py deleted file mode 100644 index fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Inline bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Inline", -] - - -class Inline(bijector.Bijector): - """Bijector constructed from custom callables. - - Example Use: - - ```python - exp = Inline( - forward_fn=tf.exp, - inverse_fn=tf.log, - inverse_log_det_jacobian_fn=( - lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), - name="exp") - ``` - - The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. - """ - - def __init__(self, - forward_fn=None, - inverse_fn=None, - inverse_log_det_jacobian_fn=None, - forward_log_det_jacobian_fn=None, - forward_event_shape_fn=None, - forward_event_shape_tensor_fn=None, - inverse_event_shape_fn=None, - inverse_event_shape_tensor_fn=None, - is_constant_jacobian=False, - validate_args=False, - name="inline"): - """Creates a `Bijector` from callables. - - Args: - forward_fn: Python callable implementing the forward transformation. - inverse_fn: Python callable implementing the inverse transformation. - inverse_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the inverse transformation. - forward_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the forward transformation. - forward_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. - forward_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - is_constant_jacobian: Python `bool` indicating that the Jacobian is - constant for all input arguments. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - super(Inline, self).__init__( - event_ndims=0, - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - self._forward_fn = forward_fn - self._inverse_fn = inverse_fn - self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn - self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn - self._forward_event_shape_fn = forward_event_shape_fn - self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn - self._inverse_event_shape_fn = inverse_event_shape_fn - self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn - - def _forward_event_shape(self, input_shape): - if self._forward_event_shape_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_fn(input_shape) - - def _forward_event_shape_tensor(self, input_shape): - if self._forward_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_tensor_fn(input_shape) - - def _inverse_event_shape(self, output_shape): - if self._inverse_event_shape_fn is None: - # By default assume shape doesn't change. - return output_shape - return self._inverse_event_shape_fn(output_shape) - - def _inverse_event_shape_tensor(self, output_shape): - if self._inverse_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return output_shape - return self._inverse_event_shape_tensor_fn(output_shape) - - def _forward(self, x, **kwargs): - if not callable(self._forward_fn): - raise NotImplementedError( - "forward_fn is not a callable function.") - return self._forward_fn(x, **kwargs) - - def _inverse(self, y, **kwargs): - if not callable(self._inverse_fn): - raise NotImplementedError( - "inverse_fn is not a callable function.") - return self._inverse_fn(y, **kwargs) - - def _inverse_log_det_jacobian(self, y, **kwargs): - if not callable(self._inverse_log_det_jacobian_fn): - raise NotImplementedError( - "inverse_log_det_jacobian_fn is not a callable function.") - return self._inverse_log_det_jacobian_fn(y, **kwargs) - - def _forward_log_det_jacobian(self, y, **kwargs): - if not callable(self._forward_log_det_jacobian_fn): - raise NotImplementedError( - "forward_log_det_jacobian_fn is not a callable function.") - return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index c134e10109ce5065eb58de1d847e3c487258954c..2c603fe61f36dd27f4984fe6c13c11f2fb534321 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -18,12 +18,85 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.invert_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector as bijector_lib -_allowed_symbols = ["Invert"] +__all__ = [ + "Invert", +] -remove_undocumented(__name__, _allowed_symbols) + +class Invert(bijector_lib.Bijector): + """Bijector which inverts another Bijector. + + Example Use: [ExpGammaDistribution (see Background & Context)]( + https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) + models `Y=log(X)` where `X ~ Gamma`. + + ```python + exp_gamma_distribution = TransformedDistribution( + distribution=Gamma(concentration=1., rate=2.), + bijector=bijector.Invert(bijector.Exp()) + ``` + + """ + + def __init__(self, bijector, validate_args=False, name=None): + """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. + + Note: An inverted bijector's `inverse_log_det_jacobian` is often more + efficient if the base bijector implements `_forward_log_det_jacobian`. If + `_forward_log_det_jacobian` is not implemented then the following code is + used: + + ```python + y = self.inverse(x, **kwargs) + return -self.inverse_log_det_jacobian(y, **kwargs) + ``` + + Args: + bijector: Bijector instance. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + + if not bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijectors.") + + self._bijector = bijector + super(Invert, self).__init__( + event_ndims=bijector.event_ndims, + graph_parents=bijector.graph_parents, + is_constant_jacobian=bijector.is_constant_jacobian, + validate_args=validate_args, + dtype=bijector.dtype, + name=name or "_".join(["invert", bijector.name])) + + def _forward_event_shape(self, input_shape): + return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access + + def _forward_event_shape_tensor(self, input_shape): + return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access + + def _inverse_event_shape(self, output_shape): + return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access + + def _inverse_event_shape_tensor(self, output_shape): + return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access + + @property + def bijector(self): + return self._bijector + + def _forward(self, x, **kwargs): + return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access + + def _inverse(self, y, **kwargs): + return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access + + def _inverse_log_det_jacobian(self, y, **kwargs): + return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access + + def _forward_log_det_jacobian(self, x, **kwargs): + return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py deleted file mode 100644 index 2c603fe61f36dd27f4984fe6c13c11f2fb534321..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Invert bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector as bijector_lib - -__all__ = [ - "Invert", -] - - -class Invert(bijector_lib.Bijector): - """Bijector which inverts another Bijector. - - Example Use: [ExpGammaDistribution (see Background & Context)]( - https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) - models `Y=log(X)` where `X ~ Gamma`. - - ```python - exp_gamma_distribution = TransformedDistribution( - distribution=Gamma(concentration=1., rate=2.), - bijector=bijector.Invert(bijector.Exp()) - ``` - - """ - - def __init__(self, bijector, validate_args=False, name=None): - """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. - - Note: An inverted bijector's `inverse_log_det_jacobian` is often more - efficient if the base bijector implements `_forward_log_det_jacobian`. If - `_forward_log_det_jacobian` is not implemented then the following code is - used: - - ```python - y = self.inverse(x, **kwargs) - return -self.inverse_log_det_jacobian(y, **kwargs) - ``` - - Args: - bijector: Bijector instance. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - - if not bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijectors.") - - self._bijector = bijector - super(Invert, self).__init__( - event_ndims=bijector.event_ndims, - graph_parents=bijector.graph_parents, - is_constant_jacobian=bijector.is_constant_jacobian, - validate_args=validate_args, - dtype=bijector.dtype, - name=name or "_".join(["invert", bijector.name])) - - def _forward_event_shape(self, input_shape): - return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access - - def _forward_event_shape_tensor(self, input_shape): - return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access - - def _inverse_event_shape(self, output_shape): - return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access - - def _inverse_event_shape_tensor(self, output_shape): - return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access - - @property - def bijector(self): - return self._bijector - - def _forward(self, x, **kwargs): - return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access - - def _inverse(self, y, **kwargs): - return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access - - def _inverse_log_det_jacobian(self, y, **kwargs): - return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access - - def _forward_log_det_jacobian(self, x, **kwargs): - return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 132dc570f94719b6c71fb269866c943774481b7e..dc8ae1eed19eda772219287d8661f534ac242d10 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -18,16 +18,484 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = [ +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template as template_ops +from tensorflow.python.ops import variable_scope as variable_scope_lib +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ "MaskedAutoregressiveFlow", - "masked_dense", "masked_autoregressive_default_template", + "masked_dense", ] -remove_undocumented(__name__, _allowed_symbols) + +class MaskedAutoregressiveFlow(bijector_lib.Bijector): + """Affine MaskedAutoregressiveFlow bijector for vector-valued events. + + The affine autoregressive flow [1] provides a relatively simple framework for + user-specified (deep) architectures to learn a distribution over vector-valued + events. Regarding terminology, + + "Autoregressive models decompose the joint density as a product of + conditionals, and model each conditional in turn. Normalizing flows + transform a base density (e.g. a standard Gaussian) into the target density + by an invertible transformation with tractable Jacobian." [1] + + In other words, the "autoregressive property" is equivalent to the + decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided + `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves + this property by zeroing out weights in its `masked_dense` layers. + + In the `tf.distributions` framework, a "normalizing flow" is implemented as a + `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + is implemented using a `tf.while_loop` and a deep neural network (DNN) with + masked weights such that the autoregressive property is automatically met in + the `inverse`. + + A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the + (expensive) forward-mode calculation to draw samples and the (cheap) + reverse-mode calculation to compute log-probabilities. Conversely, a + `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses + the (expensive) forward-mode calculation to compute log-probabilities and the + (cheap) reverse-mode calculation to compute samples. See "Example Use" + [below] for more details. + + Given a `shift_and_log_scale_fn`, the forward and inverse transformations are + (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` + must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka + "alpha" [2]) such that each are broadcastable with the arguments to `forward` + and `inverse`, i.e., such that the calculations in `forward`, `inverse` + [below] are possible. + + For convenience, `masked_autoregressive_default_template` is offered as a + possible `shift_and_log_scale_fn` function. It implements the MADE + architecture [2]. MADE is a feed-forward network that computes a `shift` and + `log(scale)` using `masked_dense` layers in a deep neural network. Weights are + masked to ensure the autoregressive property. It is possible that this + architecture is suboptimal for your task. To build alternative networks, + either change the arguments to `masked_autoregressive_default_template`, use + the `masked_dense` function to roll-out your own, or use some other + architecture, e.g., using `tf.layers`. + + Warning: no attempt is made to validate that the `shift_and_log_scale_fn` + enforces the "autoregressive property". + + Assuming `shift_and_log_scale_fn` has valid shape and autoregressive + semantics, the forward transformation is, + + ```python + def forward(x): + y = zeros_like(x) + event_size = x.shape[-1] + for _ in range(event_size): + shift, log_scale = shift_and_log_scale_fn(y) + y = x * math_ops.exp(log_scale) + shift + return y + ``` + + and the inverse transformation is, + + ```python + def inverse(y): + shift, log_scale = shift_and_log_scale_fn(y) + return (y - shift) / math_ops.exp(log_scale) + ``` + + Notice that the `inverse` does not need a for-loop. This is because in the + forward pass each calculation of `shift` and `log_scale` is based on the `y` + calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is + equivalent to the scaling used in `forward` after `event_size` passes, i.e., + the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this + also proves the transform is bijective.) + + #### Example Use + + ```python + tfd = tf.contrib.distributions + tfb = tfd.bijectors + + dims = 5 + + # A common choice for a normalizing flow is to use a Gaussian for the base + # distribution. (However, any continuous distribution would work.) E.g., + maf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512])), + event_shape=[dims]) + + x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. + maf.log_prob(x) # Almost free; uses Bijector caching. + maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. + + # [1] also describes an "Inverse Autoregressive Flow", e.g., + iaf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512]))), + event_shape=[dims]) + + x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. + iaf.log_prob(x) # Almost free; uses Bijector caching. + iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. + + # In many (if not most) cases the default `shift_and_log_scale_fn` will be a + # poor choice. Here's an example of using a "shift only" version and with a + # different number/depth of hidden layers. + shift_only = True + maf_no_scale_hidden2 = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + tfb.masked_autoregressive_default_template( + hidden_layers=[32], + shift_only=shift_only), + is_constant_jacobian=shift_only), + event_shape=[dims]) + ``` + + [1]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [2]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + """ + + def __init__(self, + shift_and_log_scale_fn, + is_constant_jacobian=False, + validate_args=False, + unroll_loop=False, + name=None): + """Creates the MaskedAutoregressiveFlow bijector. + + Args: + shift_and_log_scale_fn: Python `callable` which computes `shift` and + `log_scale` from both the forward domain (`x`) and the inverse domain + (`y`). Calculation must respect the "autoregressive property" (see class + docstring). Suggested default + `masked_autoregressive_default_template(hidden_layers=...)`. + Typically the function contains `tf.Variables` and is wrapped using + `tf.make_template`. Returning `None` for either (both) `shift`, + `log_scale` is equivalent to (but more efficient than) returning zero. + is_constant_jacobian: Python `bool`. Default: `False`. When `True` the + implementation assumes `log_scale` does not depend on the forward domain + (`x`) or inverse domain (`y`) values. (No validation is made; + `is_constant_jacobian=False` is always safe but possibly computationally + inefficient.) + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + unroll_loop: Python `bool` indicating whether the `tf.while_loop` in + `_forward` should be replaced with a static for loop. Requires that + the final dimension of `x` be known at graph construction time. Defaults + to `False`. + name: Python `str`, name given to ops managed by this object. + """ + name = name or "masked_autoregressive_flow" + self._shift_and_log_scale_fn = shift_and_log_scale_fn + self._unroll_loop = unroll_loop + super(MaskedAutoregressiveFlow, self).__init__( + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self._unroll_loop: + event_size = x.shape.with_rank_at_least(1)[-1].value + if event_size is None: + raise ValueError( + "The final dimension of `x` must be known at graph construction " + "time if `unroll_loop=True`. `x.shape: %r`" % x.shape) + y = array_ops.zeros_like(x, name="y0") + + for _ in range(event_size): + shift, log_scale = self._shift_and_log_scale_fn(y) + # next_y = scale * x + shift + next_y = x + if log_scale is not None: + next_y *= math_ops.exp(log_scale) + if shift is not None: + next_y += shift + y = next_y + return y + + event_size = array_ops.shape(x)[-1] + y0 = array_ops.zeros_like(x, name="y0") + # call the template once to ensure creation + _ = self._shift_and_log_scale_fn(y0) + def _loop_body(index, y0): + """While-loop body for autoregression calculation.""" + # Set caching device to avoid re-getting the tf.Variable for every while + # loop iteration. + with variable_scope_lib.variable_scope( + variable_scope_lib.get_variable_scope()) as vs: + if vs.caching_device is None: + vs.set_caching_device(lambda op: op.device) + shift, log_scale = self._shift_and_log_scale_fn(y0) + y = x + if log_scale is not None: + y *= math_ops.exp(log_scale) + if shift is not None: + y += shift + return index + 1, y + _, y = control_flow_ops.while_loop( + cond=lambda index, _: index < event_size, + body=_loop_body, + loop_vars=[0, y0]) + return y + + def _inverse(self, y): + shift, log_scale = self._shift_and_log_scale_fn(y) + x = y + if shift is not None: + x -= shift + if log_scale is not None: + x *= math_ops.exp(-log_scale) + return x + + def _inverse_log_det_jacobian(self, y): + _, log_scale = self._shift_and_log_scale_fn(y) + if log_scale is None: + return constant_op.constant(0., dtype=y.dtype, name="ildj") + return -math_ops.reduce_sum(log_scale, axis=-1) + + +MASK_INCLUSIVE = "inclusive" +MASK_EXCLUSIVE = "exclusive" + + +def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): + """Generate the slices for building an autoregressive mask.""" + # TODO(b/67594795): Better support of dynamic shape. + slices = [] + col = 0 + d_in = n_in // num_blocks + d_out = n_out // num_blocks + row = d_out if mask_type == MASK_EXCLUSIVE else 0 + for _ in range(num_blocks): + row_slice = slice(row, None) + col_slice = slice(col, col + d_in) + slices.append([row_slice, col_slice]) + col += d_in + row += d_out + return slices + + +def _gen_mask(num_blocks, + n_in, + n_out, + mask_type=MASK_EXCLUSIVE, + dtype=dtypes.float32): + """Generate the mask for building an autoregressive dense layer.""" + # TODO(b/67594795): Better support of dynamic shape. + mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) + slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) + for [row_slice, col_slice] in slices: + mask[row_slice, col_slice] = 1 + return mask + + +def masked_dense(inputs, + units, + num_blocks=None, + exclusive=False, + kernel_initializer=None, + reuse=None, + name=None, + *args, + **kwargs): + """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. + + See [1] for detailed explanation. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + inputs: Tensor input. + units: Python `int` scalar representing the dimensionality of the output + space. + num_blocks: Python `int` scalar representing the number of blocks for the + MADE masks. + exclusive: Python `bool` scalar representing whether to zero the diagonal of + the mask, used for the first layer of a MADE. + kernel_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the + `tf.glorot_random_initializer`. + reuse: Python `bool` scalar representing whether to reuse the weights of a + previous layer by the same name. + name: Python `str` used to describe ops managed by this function. + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + Output tensor. + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + # TODO(b/67594795): Better support of dynamic shape. + input_depth = inputs.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + + mask = _gen_mask(num_blocks, input_depth, units, + MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T + + if kernel_initializer is None: + kernel_initializer = init_ops.glorot_normal_initializer() + + def masked_initializer(shape, dtype=None, partition_info=None): + return mask * kernel_initializer(shape, dtype, partition_info) + + with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): + layer = layers.Dense( + units, + kernel_initializer=masked_initializer, + kernel_constraint=lambda x: mask * x, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse, + *args, + **kwargs) + return layer.apply(inputs) + + +def masked_autoregressive_default_template( + hidden_layers, + shift_only=False, + activation=nn_ops.relu, + log_scale_min_clip=-5., + log_scale_max_clip=3., + log_scale_clip_gradient=False, + name=None, + *args, + **kwargs): + """Build the MADE Model [1]. + + This will be wrapped in a make_template to ensure the variables are only + created once. It takes the input and returns the `loc` ("mu" [1]) and + `log_scale` ("alpha" [1]) from the MADE network. + + Warning: This function uses `masked_dense` to create randomly initialized + `tf.Variables`. It is presumed that these will be fit, just as you would any + other neural architecture which uses `tf.layers.dense`. + + #### About Hidden Layers: + + Each element of `hidden_layers` should be greater than the `input_depth` + (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the + neural network). This is necessary to ensure the autoregressivity property. + + #### About Clipping: + + This function also optionally clips the `log_scale` (but possibly not its + gradient). This is useful because if `log_scale` is too small/large it might + underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` + bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` + `bool` indicates whether the gradient should also be clipped. The default does + not clip the gradient; this is useful because it still provides gradient + information (for fitting) yet solves the numerical stability problem. I.e., + `log_scale_clip_gradient = False` means + `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual + `grad[clip(x)] exp(clip(x))`. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + hidden_layers: Python `list`-like of non-negative integer, scalars + indicating the number of units in each hidden layer. Default: `[512, 512]. + shift_only: Python `bool` indicating if only the `shift` term shall be + computed. Default: `False`. + activation: Activation function (callable). Explicitly setting to `None` + implies a linear activation. + log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The minimum value to clip by. Default: -5. + log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The maximum value to clip by. Default: 3. + log_scale_clip_gradient: Python `bool` indicating that the gradient of + `tf.clip_by_value` should be preserved. Default: `False`. + name: A name for ops managed by this function. Default: + "masked_autoregressive_default_template". + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + + with ops.name_scope(name, "masked_autoregressive_default_template", + values=[log_scale_min_clip, log_scale_max_clip]): + def _fn(x): + """MADE parameterized via `masked_autoregressive_default_template`.""" + # TODO(b/67594795): Better support of dynamic shape. + input_depth = x.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() + else array_ops.shape(x)) + for i, units in enumerate(hidden_layers): + x = masked_dense( + inputs=x, + units=units, + num_blocks=input_depth, + exclusive=True if i == 0 else False, + activation=activation, + *args, + **kwargs) + x = masked_dense( + inputs=x, + units=(1 if shift_only else 2) * input_depth, + num_blocks=input_depth, + activation=None, + *args, + **kwargs) + if shift_only: + x = array_ops.reshape(x, shape=input_shape) + return x, None + x = array_ops.reshape( + x, shape=array_ops.concat([input_shape, [2]], axis=0)) + shift, log_scale = array_ops.unstack(x, num=2, axis=-1) + which_clip = (math_ops.clip_by_value if log_scale_clip_gradient + else _clip_by_value_preserve_grad) + log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) + return shift, log_scale + return template_ops.make_template( + "masked_autoregressive_default_template", _fn) + + +def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): + """Clips input while leaving gradient unaltered.""" + with ops.name_scope(name, "clip_by_value_preserve_grad", + [x, clip_value_min, clip_value_max]): + clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) + return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py deleted file mode 100644 index ae142883931274b594dbbafbe86bd71e75c621bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""MaskedAutoregressiveFlow bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.layers import core as layers -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import template as template_ops -from tensorflow.python.ops import variable_scope as variable_scope_lib -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "MaskedAutoregressiveFlow", - "masked_autoregressive_default_template", - "masked_dense", -] - - -class MaskedAutoregressiveFlow(bijector_lib.Bijector): - """Affine MaskedAutoregressiveFlow bijector for vector-valued events. - - The affine autoregressive flow [1] provides a relatively simple framework for - user-specified (deep) architectures to learn a distribution over vector-valued - events. Regarding terminology, - - "Autoregressive models decompose the joint density as a product of - conditionals, and model each conditional in turn. Normalizing flows - transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] - - In other words, the "autoregressive property" is equivalent to the - decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided - `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves - this property by zeroing out weights in its `masked_dense` layers. - - In the `tf.distributions` framework, a "normalizing flow" is implemented as a - `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" - is implemented using a `tf.while_loop` and a deep neural network (DNN) with - masked weights such that the autoregressive property is automatically met in - the `inverse`. - - A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the - (expensive) forward-mode calculation to draw samples and the (cheap) - reverse-mode calculation to compute log-probabilities. Conversely, a - `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses - the (expensive) forward-mode calculation to compute log-probabilities and the - (cheap) reverse-mode calculation to compute samples. See "Example Use" - [below] for more details. - - Given a `shift_and_log_scale_fn`, the forward and inverse transformations are - (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` - must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka - "alpha" [2]) such that each are broadcastable with the arguments to `forward` - and `inverse`, i.e., such that the calculations in `forward`, `inverse` - [below] are possible. - - For convenience, `masked_autoregressive_default_template` is offered as a - possible `shift_and_log_scale_fn` function. It implements the MADE - architecture [2]. MADE is a feed-forward network that computes a `shift` and - `log(scale)` using `masked_dense` layers in a deep neural network. Weights are - masked to ensure the autoregressive property. It is possible that this - architecture is suboptimal for your task. To build alternative networks, - either change the arguments to `masked_autoregressive_default_template`, use - the `masked_dense` function to roll-out your own, or use some other - architecture, e.g., using `tf.layers`. - - Warning: no attempt is made to validate that the `shift_and_log_scale_fn` - enforces the "autoregressive property". - - Assuming `shift_and_log_scale_fn` has valid shape and autoregressive - semantics, the forward transformation is, - - ```python - def forward(x): - y = zeros_like(x) - event_size = x.shape[-1] - for _ in range(event_size): - shift, log_scale = shift_and_log_scale_fn(y) - y = x * math_ops.exp(log_scale) + shift - return y - ``` - - and the inverse transformation is, - - ```python - def inverse(y): - shift, log_scale = shift_and_log_scale_fn(y) - return (y - shift) / math_ops.exp(log_scale) - ``` - - Notice that the `inverse` does not need a for-loop. This is because in the - forward pass each calculation of `shift` and `log_scale` is based on the `y` - calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is - equivalent to the scaling used in `forward` after `event_size` passes, i.e., - the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this - also proves the transform is bijective.) - - #### Example Use - - ```python - ds = tf.contrib.distributions - bs = tf.contrib.distributions.bijectors - - dims = 5 - - # A common choice for a normalizing flow is to use a Gaussian for the base - # distribution. (However, any continuous distribution would work.) E.g., - maf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512])), - event_shape=[dims]) - - x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. - maf.log_prob(x) # Almost free; uses Bijector caching. - maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. - - # [1] also describes an "Inverse Autoregressive Flow", e.g., - iaf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.Invert(bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512]))), - event_shape=[dims]) - - x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. - iaf.log_prob(x) # Almost free; uses Bijector caching. - iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. - - # In many (if not most) cases the default `shift_and_log_scale_fn` will be a - # poor choice. Here's an example of using a "shift only" version and with a - # different number/depth of hidden layers. - shift_only = True - maf_no_scale_hidden2 = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - bs.masked_autoregressive_default_template( - hidden_layers=[32], - shift_only=shift_only), - is_constant_jacobian=shift_only), - event_shape=[dims]) - ``` - - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 - - [2]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - """ - - def __init__(self, - shift_and_log_scale_fn, - is_constant_jacobian=False, - validate_args=False, - name=None): - """Creates the MaskedAutoregressiveFlow bijector. - - Args: - shift_and_log_scale_fn: Python `callable` which computes `shift` and - `log_scale` from both the forward domain (`x`) and the inverse domain - (`y`). Calculation must respect the "autoregressive property" (see class - docstring). Suggested default - `masked_autoregressive_default_template(hidden_layers=...)`. - Typically the function contains `tf.Variables` and is wrapped using - `tf.make_template`. Returning `None` for either (both) `shift`, - `log_scale` is equivalent to (but more efficient than) returning zero. - is_constant_jacobian: Python `bool`. Default: `False`. When `True` the - implementation assumes `log_scale` does not depend on the forward domain - (`x`) or inverse domain (`y`) values. (No validation is made; - `is_constant_jacobian=False` is always safe but possibly computationally - inefficient.) - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - name = name or "masked_autoregressive_flow" - self._shift_and_log_scale_fn = shift_and_log_scale_fn - super(MaskedAutoregressiveFlow, self).__init__( - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - - def _forward(self, x): - event_size = array_ops.shape(x)[-1] - def _loop_body(index, y0): - """While-loop body for autoregression calculation.""" - # Set caching device to avoid re-getting the tf.Variable for every while - # loop iteration. - with variable_scope_lib.variable_scope( - variable_scope_lib.get_variable_scope()) as vs: - if vs.caching_device is None: - vs.set_caching_device(lambda op: op.device) - shift, log_scale = self._shift_and_log_scale_fn(y0) - y = x - if log_scale is not None: - y *= math_ops.exp(log_scale) - if shift is not None: - y += shift - return index + 1, y - _, y = control_flow_ops.while_loop( - cond=lambda index, _: index < event_size, - body=_loop_body, - loop_vars=[0, array_ops.zeros_like(x, name="y0")]) - return y - - def _inverse(self, y): - shift, log_scale = self._shift_and_log_scale_fn(y) - x = y - if shift is not None: - x -= shift - if log_scale is not None: - x *= math_ops.exp(-log_scale) - return x - - def _inverse_log_det_jacobian(self, y): - _, log_scale = self._shift_and_log_scale_fn(y) - if log_scale is None: - return constant_op.constant(0., dtype=y.dtype, name="ildj") - return -math_ops.reduce_sum(log_scale, axis=-1) - - -MASK_INCLUSIVE = "inclusive" -MASK_EXCLUSIVE = "exclusive" - - -def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): - """Generate the slices for building an autoregressive mask.""" - # TODO(b/67594795): Better support of dynamic shape. - slices = [] - col = 0 - d_in = n_in // num_blocks - d_out = n_out // num_blocks - row = d_out if mask_type == MASK_EXCLUSIVE else 0 - for _ in range(num_blocks): - row_slice = slice(row, None) - col_slice = slice(col, col + d_in) - slices.append([row_slice, col_slice]) - col += d_in - row += d_out - return slices - - -def _gen_mask(num_blocks, - n_in, - n_out, - mask_type=MASK_EXCLUSIVE, - dtype=dtypes.float32): - """Generate the mask for building an autoregressive dense layer.""" - # TODO(b/67594795): Better support of dynamic shape. - mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) - slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) - for [row_slice, col_slice] in slices: - mask[row_slice, col_slice] = 1 - return mask - - -def masked_dense(inputs, - units, - num_blocks=None, - exclusive=False, - kernel_initializer=None, - reuse=None, - name=None, - *args, - **kwargs): - """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. - - See [1] for detailed explanation. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - inputs: Tensor input. - units: Python `int` scalar representing the dimensionality of the output - space. - num_blocks: Python `int` scalar representing the number of blocks for the - MADE masks. - exclusive: Python `bool` scalar representing whether to zero the diagonal of - the mask, used for the first layer of a MADE. - kernel_initializer: Initializer function for the weight matrix. - If `None` (default), weights are initialized using the - `tf.glorot_random_initializer`. - reuse: Python `bool` scalar representing whether to reuse the weights of a - previous layer by the same name. - name: Python `str` used to describe ops managed by this function. - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - Output tensor. - - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - # TODO(b/67594795): Better support of dynamic shape. - input_depth = inputs.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - - mask = _gen_mask(num_blocks, input_depth, units, - MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T - - if kernel_initializer is None: - kernel_initializer = init_ops.glorot_normal_initializer() - - def masked_initializer(shape, dtype=None, partition_info=None): - return mask * kernel_initializer(shape, dtype, partition_info) - - with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): - layer = layers.Dense( - units, - kernel_initializer=masked_initializer, - kernel_constraint=lambda x: mask * x, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse, - *args, - **kwargs) - return layer.apply(inputs) - - -def masked_autoregressive_default_template( - hidden_layers, - shift_only=False, - activation=nn_ops.relu, - log_scale_min_clip=-5., - log_scale_max_clip=3., - log_scale_clip_gradient=False, - name=None, - *args, - **kwargs): - """Build the MADE Model [1]. - - This will be wrapped in a make_template to ensure the variables are only - created once. It takes the input and returns the `loc` ("mu" [1]) and - `log_scale` ("alpha" [1]) from the MADE network. - - Warning: This function uses `masked_dense` to create randomly initialized - `tf.Variables`. It is presumed that these will be fit, just as you would any - other neural architecture which uses `tf.layers.dense`. - - #### About Hidden Layers: - - Each element of `hidden_layers` should be greater than the `input_depth` - (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the - neural network). This is necessary to ensure the autoregressivity property. - - #### About Clipping: - - This function also optionally clips the `log_scale` (but possibly not its - gradient). This is useful because if `log_scale` is too small/large it might - underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` - bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` - `bool` indicates whether the gradient should also be clipped. The default does - not clip the gradient; this is useful because it still provides gradient - information (for fitting) yet solves the numerical stability problem. I.e., - `log_scale_clip_gradient = False` means - `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual - `grad[clip(x)] exp(clip(x))`. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - hidden_layers: Python `list`-like of non-negative integer, scalars - indicating the number of units in each hidden layer. Default: `[512, 512]. - shift_only: Python `bool` indicating if only the `shift` term shall be - computed. Default: `False`. - activation: Activation function (callable). Explicitly setting to `None` - implies a linear activation. - log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The minimum value to clip by. Default: -5. - log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The maximum value to clip by. Default: 3. - log_scale_clip_gradient: Python `bool` indicating that the gradient of - `tf.clip_by_value` should be preserved. Default: `False`. - name: A name for ops managed by this function. Default: - "masked_autoregressive_default_template". - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). - - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - - with ops.name_scope(name, "masked_autoregressive_default_template", - values=[log_scale_min_clip, log_scale_max_clip]): - def _fn(x): - """MADE parameterized via `masked_autoregressive_default_template`.""" - # TODO(b/67594795): Better support of dynamic shape. - input_depth = x.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() - else array_ops.shape(x)) - for i, units in enumerate(hidden_layers): - x = masked_dense( - inputs=x, - units=units, - num_blocks=input_depth, - exclusive=True if i == 0 else False, - activation=activation, - *args, - **kwargs) - x = masked_dense( - inputs=x, - units=(1 if shift_only else 2) * input_depth, - num_blocks=input_depth, - activation=None, - *args, - **kwargs) - if shift_only: - x = array_ops.reshape(x, shape=input_shape) - return x, None - x = array_ops.reshape( - x, shape=array_ops.concat([input_shape, [2]], axis=0)) - shift, log_scale = array_ops.unstack(x, num=2, axis=-1) - which_clip = (math_ops.clip_by_value if log_scale_clip_gradient - else _clip_by_value_preserve_grad) - log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) - return shift, log_scale - return template_ops.make_template( - "masked_autoregressive_default_template", _fn) - - -def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): - """Clips input while leaving gradient unaltered.""" - with ops.name_scope(name, "clip_by_value_preserve_grad", - [x, clip_value_min, clip_value_max]): - clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) - return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index a187ce22d686ee1203802ae2bfe64b0e1a3ea850..8654cc39d0c41ec4f1b85cd5fc4366ceaf4b224d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -12,18 +12,127 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Permute bijector.""" +"""Permutation bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.permute_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Permute"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Permute", +] + + +class Permute(bijector_lib.Bijector): + """Permutes the rightmost dimension of a `Tensor`. + + ```python + tfd = tf.contrib.distributions + + reverse = tfd.bijectors.Permute(permutation=[2, 1, 0]) + + reverse.forward([-1., 0., 1.]) + # ==> [1., 0., -1] + + reverse.inverse([1., 0., -1]) + # ==> [-1., 0., 1.] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + Warning: `tf.estimator` may repeatedly build the graph thus + `Permute(np.random.permutation(event_size)).astype("int32"))` is not a + reliable parameterization (nor would it be even if using `tf.constant`). A + safe alternative is to use `tf.get_variable` to achieve "init once" behavior, + i.e., + + ```python + def init_once(x, name): + return tf.get_variable(name, initializer=x, trainable=False) + + Permute(permutation=init_once( + np.random.permutation(event_size).astype("int32"), + name="permutation")) + ``` + + """ + + def __init__(self, permutation, validate_args=False, name=None): + """Creates the `Permute` bijector. + + Args: + permutation: An `int`-like vector-shaped `Tensor` representing the + permutation to apply to the rightmost dimension of the transformed + `Tensor`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if `not permutation.dtype.is_integer`. + ValueError: if `permutation` does not contain exactly one of each of + `{0, 1, ..., d}`. + """ + with ops.name_scope(name, "permute", values=[permutation]): + permutation = ops.convert_to_tensor( + permutation, + name="permutation") + if not permutation.dtype.is_integer: + raise TypeError("permutation.dtype ({}) should be `int`-like.".format( + permutation.dtype.name)) + p = tensor_util.constant_value(permutation) + if p is not None: + if set(p) != set(np.arange(p.size)): + raise ValueError("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.") + elif validate_args: + p, _ = nn_ops.top_k(-permutation, + k=array_ops.shape(permutation)[-1], + sorted=True) + permutation = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + -p, math_ops.range(array_ops.size(p)), + message=("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.")), + ], permutation) + self._permutation = permutation + super(Permute, self).__init__( + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "permute") + + @property + def permutation(self): + return self._permutation + + def _forward(self, x): + return array_ops.gather(x, self.permutation, axis=-1) + + def _inverse(self, y): + return array_ops.gather( + y, + array_ops.invert_permutation(self.permutation), + axis=-1) + + def _inverse_log_det_jacobian(self, y): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py deleted file mode 100644 index b1d8f2f41b28a88208a19824377f93882b767f03..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Permutation bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Permute", -] - - -class Permute(bijector_lib.Bijector): - """Permutes the rightmost dimension of a `Tensor`. - - ```python - bs = tf.contrib.distributions.bijectors - - reverse = bs.Permute(permutation=[2, 1, 0]) - - reverse.forward([-1., 0., 1.]) - # ==> [1., 0., -1] - - reverse.inverse([1., 0., -1]) - # ==> [-1., 0., 1.] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. - - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - Warning: `tf.estimator` may repeatedly build the graph thus - `Permute(np.random.permutation(event_size)).astype("int32"))` is not a - reliable parameterization (nor would it be even if using `tf.constant`). A - safe alternative is to use `tf.get_variable` to achieve "init once" behavior, - i.e., - - ```python - def init_once(x, name): - return tf.get_variable(name, initializer=x, trainable=False) - - Permute(permutation=init_once( - np.random.permutation(event_size).astype("int32"), - name="permutation")) - ``` - - """ - - def __init__(self, permutation, validate_args=False, name=None): - """Creates the `Permute` bijector. - - Args: - permutation: An `int`-like vector-shaped `Tensor` representing the - permutation to apply to the rightmost dimension of the transformed - `Tensor`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if `not permutation.dtype.is_integer`. - ValueError: if `permutation` does not contain exactly one of each of - `{0, 1, ..., d}`. - """ - with ops.name_scope(name, "permute", values=[permutation]): - permutation = ops.convert_to_tensor( - permutation, - name="permutation") - if not permutation.dtype.is_integer: - raise TypeError("permutation.dtype ({}) should be `int`-like.".format( - permutation.dtype.name)) - p = tensor_util.constant_value(permutation) - if p is not None: - if set(p) != set(np.arange(p.size)): - raise ValueError("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.") - elif validate_args: - p, _ = nn_ops.top_k(-permutation, - k=array_ops.shape(permutation)[-1], - sorted=True) - permutation = control_flow_ops.with_dependencies([ - check_ops.assert_equal( - -p, math_ops.range(array_ops.size(p)), - message=("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.")), - ], permutation) - self._permutation = permutation - super(Permute, self).__init__( - is_constant_jacobian=True, - validate_args=validate_args, - name=name or "permute") - - @property - def permutation(self): - return self._permutation - - def _forward(self, x): - return array_ops.gather(x, self.permutation, axis=-1) - - def _inverse(self, y): - return array_ops.gather( - y, - array_ops.invert_permutation(self.permutation), - axis=-1) - - def _inverse_log_det_jacobian(self, y): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index a83199549cd16101ab7b39b43d19a17bc66f03df..c37db61720d10949f294ff7b2e9778ba6efa57f0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -18,12 +18,110 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.power_transform_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["PowerTransform"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "PowerTransform", +] + + +class PowerTransform(bijector.Bijector): + """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. + + The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps + inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` + of this bijector. + + This bijector is equivalent to the `Exp` bijector when `c=0`. + """ + + def __init__(self, + power=0., + event_ndims=0, + validate_args=False, + name="power_transform"): + """Instantiates the `PowerTransform` bijector. + + Args: + power: Python `float` scalar indicating the transform power, i.e., + `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `power < 0` or is not known statically. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[power]): + power = tensor_util.constant_value( + ops.convert_to_tensor(power, name="power")) + if power is None or power < 0: + raise ValueError("`power` must be a non-negative TF constant.") + self._power = power + super(PowerTransform, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def power(self): + """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" + return self._power + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + if self.power == 0.: + return math_ops.exp(x) + # If large x accuracy is an issue, consider using: + # (1. + x * self.power)**(1. / self.power) when x >> 1. + return math_ops.exp(math_ops.log1p(x * self.power) / self.power) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + if self.power == 0.: + return math_ops.log(y) + # If large y accuracy is an issue, consider using: + # (y**self.power - 1.) / self.power when y >> 1. + return math_ops.expm1(math_ops.log(y) * self.power) / self.power + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return (self.power - 1.) * math_ops.reduce_sum( + math_ops.log(y), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + if self.power == 0.: + return math_ops.reduce_sum(x, axis=event_dims) + return (1. / self.power - 1.) * math_ops.reduce_sum( + math_ops.log1p(x * self.power), + axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args or self.power == 0.: + return x + is_valid = check_ops.assert_non_negative( + 1. + self.power * x, + message="Forward transformation input must be at least {}.".format( + -1. / self.power)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_valid = check_ops.assert_positive( + y, message="Inverse transformation input must be greater than 0.") + return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py deleted file mode 100644 index c37db61720d10949f294ff7b2e9778ba6efa57f0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""PowerTransform bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "PowerTransform", -] - - -class PowerTransform(bijector.Bijector): - """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. - - The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps - inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` - of this bijector. - - This bijector is equivalent to the `Exp` bijector when `c=0`. - """ - - def __init__(self, - power=0., - event_ndims=0, - validate_args=False, - name="power_transform"): - """Instantiates the `PowerTransform` bijector. - - Args: - power: Python `float` scalar indicating the transform power, i.e., - `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `power < 0` or is not known statically. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[power]): - power = tensor_util.constant_value( - ops.convert_to_tensor(power, name="power")) - if power is None or power < 0: - raise ValueError("`power` must be a non-negative TF constant.") - self._power = power - super(PowerTransform, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def power(self): - """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" - return self._power - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - if self.power == 0.: - return math_ops.exp(x) - # If large x accuracy is an issue, consider using: - # (1. + x * self.power)**(1. / self.power) when x >> 1. - return math_ops.exp(math_ops.log1p(x * self.power) / self.power) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - if self.power == 0.: - return math_ops.log(y) - # If large y accuracy is an issue, consider using: - # (y**self.power - 1.) / self.power when y >> 1. - return math_ops.expm1(math_ops.log(y) * self.power) / self.power - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return (self.power - 1.) * math_ops.reduce_sum( - math_ops.log(y), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - if self.power == 0.: - return math_ops.reduce_sum(x, axis=event_dims) - return (1. / self.power - 1.) * math_ops.reduce_sum( - math_ops.log1p(x * self.power), - axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args or self.power == 0.: - return x - is_valid = check_ops.assert_non_negative( - 1. + self.power * x, - message="Forward transformation input must be at least {}.".format( - -1. / self.power)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_valid = check_ops.assert_positive( - y, message="Inverse transformation input must be greater than 0.") - return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 8997f7ab6929745275edb38712a5bbb0a9b25ddb..55eca063126797d577653f0d6bcdfddf8192bdb5 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -12,18 +12,303 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Reshape bijector.""" +"""Reshape bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Reshape"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Reshape", +] + + +def _static_ndims_from_shape(shape): + return shape.shape.with_rank_at_least(1)[0].value + + +def _ndims_from_shape(shape): + return array_ops.shape(shape)[0] + + +class Reshape(bijector_lib.Bijector): + """Reshapes the `event_shape` of a `Tensor`. + + The semantics generally follow that of `tf.reshape()`, with + a few differences: + + * The user must provide both the input and output shape, so that + the transformation can be inverted. If an input shape is not + specified, the default assumes a vector-shaped input, i.e., + event_shape_in = (-1,). + * The `Reshape` bijector automatically broadcasts over the leftmost + dimensions of its input (`sample_shape` and `batch_shape`); only + the rightmost `event_ndims_in` dimensions are reshaped. The + number of dimensions to reshape is inferred from the provided + `event_shape_in` (`event_ndims_in = len(event_shape_in)`). + + Example usage: + ```python + + tfd = tf.contrib.distributions + + r = tfd.bijectors.Reshape(event_shape_out=[1, -1]) + + r.forward([3., 4.]) # shape [2] + # ==> [[3., 4.]] # shape [1, 2] + + r.forward([[1., 2.], [3., 4.]]) # shape [2, 2] + # ==> [[[1., 2.]], + # [[3., 4.]]] # shape [2, 1, 2] + + r.inverse([[3., 4.]]) # shape [1,2] + # ==> [3., 4.] # shape [2] + + r.forward_log_det_jacobian(any_value) + # ==> 0. + + r.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + """ + + def __init__(self, event_shape_out, event_shape_in=(-1,), + validate_args=False, name=None): + """Creates a `Reshape` bijector. + + Args: + event_shape_out: An `int`-like vector-shaped `Tensor` + representing the event shape of the transformed output. + event_shape_in: An optional `int`-like vector-shape `Tensor` + representing the event shape of the input. This is required in + order to define inverse operations; the default of (-1,) + assumes a vector-shaped input. + validate_args: Python `bool` indicating whether arguments should + be checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if either `event_shape_in` or `event_shape_out` has + non-integer `dtype`. + ValueError: if either of `event_shape_in` or `event_shape_out` + has non-vector shape (`rank > 1`), or if their sizes do not + match. + """ + with ops.name_scope(name, "reshape", + values=[event_shape_out, event_shape_in]): + + event_shape_out = ops.convert_to_tensor(event_shape_out, + name="event_shape_out", + preferred_dtype=dtypes.int32) + event_shape_in = ops.convert_to_tensor(event_shape_in, + name="event_shape_in", + preferred_dtype=dtypes.int32) + + assertions = [] + assertions.extend(self._maybe_check_valid_shape( + event_shape_out, validate_args)) + assertions.extend(self._maybe_check_valid_shape( + event_shape_in, validate_args)) + + self._assertions = assertions + self._event_shape_in = event_shape_in + self._event_shape_out = event_shape_out + + super(Reshape, self).__init__(is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") + + def _maybe_check_valid_shape(self, shape, validate_args): + """Check that a shape Tensor is int-type and otherwise sane.""" + if not shape.dtype.is_integer: + raise TypeError("{} dtype ({}) should be `int`-like.".format( + shape.op.name, shape.dtype.name)) + + assertions = [] + + ndims = array_ops.rank(shape) + ndims_ = tensor_util.constant_value(ndims) + if ndims_ is not None and ndims_ > 1: + raise ValueError("`{}` rank ({}) should be <= 1.".format( + shape.op.name, ndims_)) + elif validate_args: + assertions.append(check_ops.assert_less_equal( + ndims, 1, message="`{}` rank should be <= 1.".format(shape.op.name))) + + shape_ = tensor_util.constant_value_as_shape(shape) + if shape_.is_fully_defined(): + es = np.int32(shape_.as_list()) + if sum(es == -1) > 1: + raise ValueError( + "`{}` must have at most one `-1` (given {})" + .format(shape.op.name, es)) + if np.any(es < -1): + raise ValueError( + "`{}` elements must be either positive integers or `-1`" + "(given {})." + .format(shape.op.name, es)) + elif validate_args: + assertions.extend([ + check_ops.assert_less_equal( + math_ops.reduce_sum( + math_ops.cast(math_ops.equal(shape, -1), dtypes.int32)), + 1, + message="`{}` elements must have at most one `-1`." + .format(shape.op.name)), + check_ops.assert_greater_equal( + shape, -1, + message="`{}` elements must be either positive integers or `-1`." + .format(shape.op.name)), + ]) + return assertions + + def _reshape_helper(self, x, event_shape_in, event_shape_out): + """Reshape only the event_shape of an input `Tensor`.""" + + event_ndims_in_ = _static_ndims_from_shape(event_shape_in) + event_ndims_in = _ndims_from_shape(event_shape_in) + x_ndims_, x_ndims = x.shape.ndims, array_ops.rank(x) + + assertions = [] + + # Ensure x.event_shape is compatible with event_shape_in. + if (event_ndims_in_ is not None + and x_ndims_ is not None + and x.shape.with_rank_at_least(event_ndims_in_)[ + x_ndims_-event_ndims_in_:].is_fully_defined()): + x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking + np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 + else: + x_event_shape_, x_event_shape = ( + None, array_ops.shape(x)[x_ndims-event_ndims_in:]) + + event_shape_in_ = tensor_util.constant_value(event_shape_in) + + if x_event_shape_ is not None and event_shape_in_ is not None: + # Compare the shape dimensions that are fully specified in the + # input (i.e., for which event_shape_in is not -1). If x_event_shape + # matches along all of these dimensions, it is compatible with + # the desired input shape and any further mismatches (i.e., + # imcompatibility with the desired *output* shape) will be + # caught inside of array_ops.reshape() below. + x_event_shape_specified_ = x_event_shape_[event_shape_in_ >= 0] + event_shape_in_specified_ = event_shape_in_[event_shape_in_ >= 0] + if not np.equal(x_event_shape_specified_, + event_shape_in_specified_).all(): + raise ValueError( + "Input `event_shape` does not match `event_shape_in` ({} vs {}).". + format(x_event_shape_, event_shape_in_)) + elif self.validate_args: + # Similarly to the static case, we compare the shape dimensions + # that are fully specified in the input. We extract these + # dimensions using boolean_mask(), which requires that the mask + # have known ndims. We can assume that shape Tensors always have + # ndims==1 (this assumption is verified inside of + # _maybe_check_valid_shape), so the reshape operation is just a + # no-op that formally encodes this fact to make boolean_mask() + # happy. + event_shape_mask = array_ops.reshape(event_shape_in >= 0, [-1]) + x_event_shape_specified = array_ops.boolean_mask(x_event_shape, + event_shape_mask) + event_shape_in_specified = array_ops.boolean_mask(event_shape_in, + event_shape_mask) + assertions.append(check_ops.assert_equal( + x_event_shape_specified, event_shape_in_specified, + message="Input `event_shape` does not match `event_shape_in`.")) + + if assertions: + x = control_flow_ops.with_dependencies(assertions, x) + + # get the parts of shape(x) that will not change + sample_and_batch_shape = array_ops.shape(x) + + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x)) + sample_and_batch_shape = sample_and_batch_shape[ + :(ndims - math_ops.abs(event_ndims_in))] + + if (event_ndims_in_ is not None + and x_ndims_ is not None + and event_ndims_in_ == x_ndims_): + # Hack to allow forward/inverse_event_shape to do shape + # inference by calling this helper method with a dummy Tensor of + # shape event_shape_in. In this special case, + # sample_and_batch_shape will be empty so we can preserve static + # shape information by avoiding the concat operation below + # (which would be a no-op). + new_shape = event_shape_out + else: + new_shape = array_ops.concat( + [sample_and_batch_shape, event_shape_out], axis=0) + + return array_ops.reshape(x, new_shape) + + def _forward(self, x): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(x, + self._event_shape_in, + self._event_shape_out) + + def _inverse(self, y): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(y, + self._event_shape_out, + self._event_shape_in) + + def _inverse_log_det_jacobian(self, y): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=x.dtype) + + def _forward_event_shape(self, input_shape): + # NOTE: this method and the other *_event_shape* methods + # compute shape by explicit transformation of a dummy + # variable. This approach is not generally recommended because it + # bloats the graph and could in general trigger side effects. + # + # In this particular case of the Reshape bijector, the + # forward and inverse transforms have no side effects, and we + # believe the reduction in code complexity from delegating the + # heavy lifting to tf.reshape() is worth the added graph ops. + # However, you should think hard before implementing this approach + # in other Bijectors; it is strongly preferred to compute + # shapes explicitly whenever it's feasible to do so. + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return dummy_reshaped.shape + + def _inverse_event_shape(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return dummy_reshaped.shape + + def _forward_event_shape_tensor(self, input_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return array_ops.shape(dummy_reshaped) + + def _inverse_event_shape_tensor(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return array_ops.shape(dummy_reshaped) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py deleted file mode 100644 index 93682639aa3be3b8f59a369dedb6ee773c468130..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py +++ /dev/null @@ -1,297 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Reshape bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Reshape", -] - - -class Reshape(bijector_lib.Bijector): - """Reshapes the `event_shape` of a `Tensor`. - - The semantics generally follow that of `tf.reshape()`, with - a few differences: - * The user must provide both the input and output shape, so that - the transformation can be inverted. - * The `Reshape` bijector automatically broadcasts over the leftmost - dimensions of its input (`sample_shape` and `batch_shape`); only - the rightmost `event_ndims_in` dimensions are reshaped. The - number of dimensions to reshape is inferred from the provided - `event_shape_in` (`event_ndims_in = len(event_shape_in)`). - * The `Reshape` bijector does not currently support - partially-specified shapes, i.e., those with a dimension - implicitly specified by `-1`. - - Example usage: - ```python - - bs = tf.contrib.distributions.bijectors - - reverse = bs.Reshape(event_shape_out=[1,2], - event_shape_in=[2,]) - - reverse.forward([1., 2.]) # shape [2,] - # ==> [[1., 2.]] # shape [1,2] - - reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2] - # ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2] - - reverse.inverse([[1., 2.]]) # shape [1,2] - # ==> [1., 2.] # shape [2,] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. - - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - """ - - def __init__(self, event_shape_out, event_shape_in, - validate_args=False, name=None): - """Creates a `Reshape` bijector. - - Args: - event_shape_out: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - transformed output. - event_shape_in: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - input. - validate_args: Python `bool` indicating whether arguments should - be checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if either `event_shape_in` or `event_shape_out` has - non-vector shape (`rank > 1`), or non-integer `dtype`. - ValueError: if either `event_shape_in` or `event_shape_out` - contains non-positive entries, or if their sizes do not match - (`prod(event_shape_in)` != `prod(event_shape_out)`), or if - their dimensionality(s) cannot be statically inferred. - """ - with ops.name_scope(name, "reshape", - values=[event_shape_out, event_shape_in]): - - event_shape_out = ops.convert_to_tensor(event_shape_out, - name="event_shape_out", - preferred_dtype=dtypes.int32) - event_shape_in = ops.convert_to_tensor(event_shape_in, - name="event_shape_in", - preferred_dtype=dtypes.int32) - - # check that input shapes are positive integers - assertions = [] - assertions += self._maybe_check_valid_shape( - event_shape_out, "event_shape_out", - validate_args=validate_args) - assertions += self._maybe_check_valid_shape( - event_shape_in, "event_shape_in", validate_args=validate_args) - - # check that prod(event_shape_in) = prod(event_shape_out) - assertions += self._maybe_check_matching_sizes( - event_shape_in, event_shape_out, validate_args=validate_args) - - self._assertions = assertions - self._event_shape_in = event_shape_in - self._event_shape_out = event_shape_out - self._event_shape_in_static = tensor_util.constant_value_as_shape( - event_shape_in) - self._event_shape_out_static = tensor_util.constant_value_as_shape( - event_shape_out) - - super(Reshape, self).__init__(is_constant_jacobian=True, - validate_args=validate_args, - name=name or "reshape") - - def _maybe_check_valid_shape(self, shape_tensor, label, - validate_args=False): - """Check that a shape Tensor is int-type and positive.""" - - assertions = [] - - if not shape_tensor.dtype.is_integer: - raise TypeError("{} dtype ({}) should be `int`-like.".format( - label, shape_tensor.dtype.name)) - - shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor)) - if shape_rank is not None and shape_rank > 1: - raise ValueError("{} rank should be <= 1.".format(label)) - - s = tensor_util.constant_value(shape_tensor) - if s is not None: - if (s <= 0).any(): - raise ValueError("{} entries must be positive, but found {}".format( - label, s)) - elif validate_args: - assertions.append(check_ops.assert_positive( - shape_tensor, message="{} entries must be positive".format(label))) - - return assertions - - def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out, - validate_args=False): - """Check that prod(event_shape_in)==prod(event_shape_out).""" - - def _get_size_from_shape(shape): - """Computes size from a shape `Tensor`, statically if possible.""" - s = tensor_util.constant_value(shape) - if s is not None: - return [np.int32(np.prod(s))]*2 - return None, math_ops.reduce_prod(shape, name="size") - - # Ensure `event_shape_in` is compatible with `event_shape_out`. - event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_in) - event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_out) - - assertions = [] - if event_size_in_ is not None and event_size_out_ is not None: - if event_size_in_ != event_size_out_: - raise ValueError( - "Input `event_size` ({}) does not match output `event_size` ({}).". - format(event_size_in, event_size_out_)) - elif validate_args: - assertions.append(check_ops.assert_equal( - event_size_in, event_size_out, - message="Input/output `event_size`s do not match.")) - - return assertions - - def _reshape_helper(self, x, event_shape_in, event_shape_out): - """Reshape only the event_shape of an input `Tensor`.""" - - def _get_rank_from_shape(shape): - """Computes rank from a shape `Tensor`, statically if possible.""" - # Uses fact that rank is "shape of shape". - ndims = shape.shape.with_rank_at_least(1)[0].value - if ndims is not None: - return ndims, ndims - return None, array_ops.shape(shape)[0] - - event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in) - - assertions = [] - # Ensure x.event_shape is compatible with event_shape_in. - if x.shape.ndims is not None: - x_ndims_, x_ndims = [x.shape.ndims]*2 - else: - x_ndims_, x_ndims = None, array_ops.rank(x) - - if (event_ndims_in_ is not None - and x_ndims_ is not None - and x.shape.with_rank_at_least(event_ndims_in_)[ - x_ndims_-event_ndims_in_:].is_fully_defined()): - x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking - np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 - else: - x_event_shape_, x_event_shape = ( - None, array_ops.shape(x)[x_ndims-event_ndims_in:]) - - event_shape_in_ = tensor_util.constant_value(event_shape_in) - - if x_event_shape_ is not None and event_shape_in_ is not None: - if not np.equal(x_event_shape_, event_shape_in_).all(): - raise ValueError( - "Input `event_shape` ({}) does not match `event_shape_in` ({}).". - format(x_event_shape_, event_shape_in_)) - elif self.validate_args: - assertions.append(check_ops.assert_equal( - x_event_shape, event_shape_in, - message="Input `event_shape` does not match `event_shape_in`.")) - - if assertions: - x = control_flow_ops.with_dependencies(assertions, x) - - # get the parts of shape(x) that will not change - sample_and_batch_shape = array_ops.shape(x) - - ndims = (x.shape.ndims if x.shape.ndims is not None - else array_ops.rank(x)) - sample_and_batch_shape = sample_and_batch_shape[ - :(ndims - math_ops.abs(event_ndims_in))] - - new_shape = array_ops.concat( - [sample_and_batch_shape, event_shape_out], axis=0) - - return array_ops.reshape(x, new_shape) - - def _forward(self, x): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(x, - self._event_shape_in, - self._event_shape_out) - - def _inverse(self, y): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(y, - self._event_shape_out, - self._event_shape_in) - - def _inverse_log_det_jacobian(self, y): - with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=x.dtype) - - def _forward_event_shape(self, input_shape): - self._event_shape_in_static.assert_is_compatible_with(input_shape) - return self._event_shape_out_static - - def _inverse_event_shape(self, output_shape): - self._event_shape_out_static.assert_is_compatible_with(output_shape) - return self._event_shape_in_static - - def _forward_event_shape_tensor(self, input_shape): - input_assertions = self._maybe_check_valid_shape( - input_shape, "input event shape", validate_args=self.validate_args) - input_assertions += self._maybe_check_matching_sizes( - input_shape, self._event_shape_out, - validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - input_assertions + self._assertions, self._event_shape_out) - - def _inverse_event_shape_tensor(self, output_shape): - - output_assertions = self._maybe_check_valid_shape( - output_shape, "output event shape", validate_args=self.validate_args) - output_assertions += self._maybe_check_matching_sizes( - output_shape, self._event_shape_in, validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - output_assertions + self._assertions, self._event_shape_in) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index c20e76c0b7367369865faf973377201c8b8b17e6..a640dfe7dfbcce96261589c7fc49107deaefdd54 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -18,12 +18,31 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Sigmoid"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Sigmoid", +] + + +class Sigmoid(bijector.Bijector): + """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" + + def __init__(self, validate_args=False, name="sigmoid"): + super(Sigmoid, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) + + def _forward(self, x): + return math_ops.sigmoid(x) + + def _inverse(self, y): + return math_ops.log(y) - math_ops.log1p(-y) + + def _inverse_log_det_jacobian(self, y): + return -math_ops.log(y) - math_ops.log1p(-y) + + def _forward_log_det_jacobian(self, x): + return -nn_ops.softplus(-x) - nn_ops.softplus(x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py index 448125230d24066697624bce03fed71a2c2f00b1..223bc9d042c69be05b0e578835a31ed6e83c0c97 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py @@ -18,12 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered -_allowed_symbols = ["SigmoidCentered"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SigmoidCentered", +] + + +class SigmoidCentered(softmax_centered.SoftmaxCentered): + """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)). + + Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. + + See `bijector.SoftmaxCentered` for more details. + """ + + def __init__(self, validate_args=False, name="sigmoid_centered"): + super(SigmoidCentered, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index b3cf03c24612f5c618c71c0a8615f272acdf2d10..3a75e4ae9495793901b0da91a5aa3982aab35852 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -18,12 +18,162 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SinhArcsinh"] +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SinhArcsinh", +] + + +def _sqrtx2p1(x): + """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" + return array_ops.where( + math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., + math_ops.sqrt(x**2. + 1.), + # For large x, calculating x**2 can overflow. This can be alleviated by + # considering: + # sqrt(1 + x**2) + # = exp(0.5 log(1 + x**2)) + # = exp(0.5 log(x**2 * (1 + x**-2))) + # = exp(log(x) + 0.5 * log(1 + x**-2)) + # = |x| * exp(0.5 log(1 + x**-2)) + # = |x| * sqrt(1 + x**-2) + # We omit the last term in this approximation. + # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, + # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, + # and higher order gradients, since the first order derivative of + # sqrt(1 + x**-2) is -2 * x**-3 / (1 + x**-2) = -2 / (x**3 + x), + # and all nth-order derivatives will be O(x**-(n + 2)). This makes any + # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. + math_ops.abs(x)) + + +class SinhArcsinh(bijector.Bijector): + """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. + + For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this + transformation is a + diffeomorphism of the real line `(-inf, inf)`. The inverse transform is + `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. + + The `SinhArcsinh` transformation of the Normal is described in + [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865) + This Bijector allows a similar transformation of any distribution supported on + `(-inf, inf)`. + + #### Meaning of the parameters + + * If `skewness = 0` and `tailweight = 1`, this transform is the identity. + * Positive (negative) `skewness` leads to positive (negative) skew. + * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is + "tilted" to the right. + * positive skew means positive values of `Y` become more likely, and + negative values become less likely. + * Larger (smaller) `tailweight` leads to fatter (thinner) tails. + * Fatter tails mean larger values of `|Y|` become more likely. + * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is + "flat" around `Y = 0`, and a very steep drop-off in the tails. + * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more + peaked at the mode with heavier tails. + + To see the argument about the tails, note that for `|X| >> 1` and + `|X| >> (|skewness| * tailweight)**tailweight`, we have + `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. + """ + + def __init__(self, + skewness=None, + tailweight=None, + event_ndims=0, + validate_args=False, + name="SinhArcsinh"): + """Instantiates the `SinhArcsinh` bijector. + + Args: + skewness: Skewness parameter. Float-type `Tensor`. Default is `0` + of type `float32`. + tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as + `skewness` and broadcastable `shape`. Default is `1` of type `float32`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[skewness, tailweight]): + tailweight = 1. if tailweight is None else tailweight + skewness = 0. if skewness is None else skewness + self._skewness = ops.convert_to_tensor( + skewness, name="skewness") + self._tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=self._skewness.dtype) + check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) + if validate_args: + self._tailweight = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._tailweight, + message="Argument tailweight was not positive") + ], self._tailweight) + super(SinhArcsinh, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def skewness(self): + """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._skewness + + @property + def tailweight(self): + """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._tailweight + + def _forward(self, x): + return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) + + def _inverse(self, y): + return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) + + def _inverse_log_det_jacobian(self, y): + # x = sinh(arcsinh(y) / tailweight - skewness) + # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), + # dx/dy + # = cosh(arcsinh(y) / tailweight - skewness) + # / (tailweight * sqrt(y**2 + 1)) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). + math_ops.log(math_ops.cosh( + math_ops.asinh(y) / self.tailweight - self.skewness) + # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases + # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). + / _sqrtx2p1(y)) + - math_ops.log(self.tailweight), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + # y = sinh((arcsinh(x) + skewness) * tailweight) + # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), + # dy/dx + # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + math_ops.log(math_ops.cosh( + (math_ops.asinh(x) + self.skewness) * self.tailweight) + # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases + # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). + / _sqrtx2p1(x)) + + math_ops.log(self.tailweight), + axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py deleted file mode 100644 index 3a75e4ae9495793901b0da91a5aa3982aab35852..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SinhArcsinh bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "SinhArcsinh", -] - - -def _sqrtx2p1(x): - """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" - return array_ops.where( - math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., - math_ops.sqrt(x**2. + 1.), - # For large x, calculating x**2 can overflow. This can be alleviated by - # considering: - # sqrt(1 + x**2) - # = exp(0.5 log(1 + x**2)) - # = exp(0.5 log(x**2 * (1 + x**-2))) - # = exp(log(x) + 0.5 * log(1 + x**-2)) - # = |x| * exp(0.5 log(1 + x**-2)) - # = |x| * sqrt(1 + x**-2) - # We omit the last term in this approximation. - # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, - # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, - # and higher order gradients, since the first order derivative of - # sqrt(1 + x**-2) is -2 * x**-3 / (1 + x**-2) = -2 / (x**3 + x), - # and all nth-order derivatives will be O(x**-(n + 2)). This makes any - # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. - math_ops.abs(x)) - - -class SinhArcsinh(bijector.Bijector): - """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. - - For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this - transformation is a - diffeomorphism of the real line `(-inf, inf)`. The inverse transform is - `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. - - The `SinhArcsinh` transformation of the Normal is described in - [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865) - This Bijector allows a similar transformation of any distribution supported on - `(-inf, inf)`. - - #### Meaning of the parameters - - * If `skewness = 0` and `tailweight = 1`, this transform is the identity. - * Positive (negative) `skewness` leads to positive (negative) skew. - * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is - "tilted" to the right. - * positive skew means positive values of `Y` become more likely, and - negative values become less likely. - * Larger (smaller) `tailweight` leads to fatter (thinner) tails. - * Fatter tails mean larger values of `|Y|` become more likely. - * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is - "flat" around `Y = 0`, and a very steep drop-off in the tails. - * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more - peaked at the mode with heavier tails. - - To see the argument about the tails, note that for `|X| >> 1` and - `|X| >> (|skewness| * tailweight)**tailweight`, we have - `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. - """ - - def __init__(self, - skewness=None, - tailweight=None, - event_ndims=0, - validate_args=False, - name="SinhArcsinh"): - """Instantiates the `SinhArcsinh` bijector. - - Args: - skewness: Skewness parameter. Float-type `Tensor`. Default is `0` - of type `float32`. - tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as - `skewness` and broadcastable `shape`. Default is `1` of type `float32`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[skewness, tailweight]): - tailweight = 1. if tailweight is None else tailweight - skewness = 0. if skewness is None else skewness - self._skewness = ops.convert_to_tensor( - skewness, name="skewness") - self._tailweight = ops.convert_to_tensor( - tailweight, name="tailweight", dtype=self._skewness.dtype) - check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) - if validate_args: - self._tailweight = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._tailweight, - message="Argument tailweight was not positive") - ], self._tailweight) - super(SinhArcsinh, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def skewness(self): - """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._skewness - - @property - def tailweight(self): - """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._tailweight - - def _forward(self, x): - return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) - - def _inverse(self, y): - return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) - - def _inverse_log_det_jacobian(self, y): - # x = sinh(arcsinh(y) / tailweight - skewness) - # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), - # dx/dy - # = cosh(arcsinh(y) / tailweight - skewness) - # / (tailweight * sqrt(y**2 + 1)) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - math_ops.asinh(y) / self.tailweight - self.skewness) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). - / _sqrtx2p1(y)) - - math_ops.log(self.tailweight), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - # y = sinh((arcsinh(x) + skewness) * tailweight) - # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), - # dy/dx - # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - (math_ops.asinh(x) + self.skewness) * self.tailweight) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). - / _sqrtx2p1(x)) - + math_ops.log(self.tailweight), - axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index be6608f97880ae68e10b17c815bf2d8438293261..a9dcce6c526600f3b26c6bceb730417000917ce7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -18,12 +18,223 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SoftmaxCentered"] +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "SoftmaxCentered", +] + + +class SoftmaxCentered(bijector.Bijector): + """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. + + To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a + bijection, the forward transformation appends a value to the input and the + inverse removes this coordinate. The appended coordinate represents a pivot, + e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last + coordinate. + + Because we append a coordinate, this bijector only supports `event_ndim in [0, + 1]`, i.e., scalars and vectors. + + Example Use: + + ```python + bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) + # Result: [0.2, 0.3, 0.4, 0.1] + # Extra result: 0.1 + + bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) + # Result: tf.log([2, 3, 4]) + # Extra coordinate removed. + ``` + + At first blush it may seem like the [Invariance of domain]( + https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this + implementation is not a bijection. However, the appended dimension + makes the (forward) image non-open and the theorem does not directly apply. + """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="softmax_centered"): + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 1]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") + self._static_event_ndims = event_ndims + super(SoftmaxCentered, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward_event_shape(self, input_shape): + if input_shape.ndims is None: + return input_shape + if input_shape.ndims != self._static_event_ndims: + raise ValueError("input_shape.dims = %d != %d" % + (input_shape.ndims, self._static_event_ndims)) + if input_shape.ndims == 0: + return tensor_shape.TensorShape([2]) + if input_shape.ndims == 1: + return tensor_shape.TensorShape(input_shape[0] + 1) + # Unreachable code: + raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) + + def _forward_event_shape_tensor(self, input_shape): + ndims = array_ops.shape(input_shape) + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. + is_zero_or_one = check_ops.assert_equal( + ndims, 0 if self._static_event_ndims == 0 else 1, + message="event_ndims must be 0 or 1") + ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor( + [2], dtype=dtypes.int32, name="output_shape") + return input_shape + 1 + + def _inverse_event_shape(self, output_shape): + if output_shape.ndims is None: + return output_shape + if output_shape.ndims != 1: + raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) + if self._static_event_ndims == 0: + return tensor_shape.TensorShape([]) + return tensor_shape.TensorShape(output_shape[0] - 1) + + def _inverse_event_shape_tensor(self, output_shape): + ndims = array_ops.shape(output_shape)[0] + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. + is_one = check_ops.assert_equal( + ndims, 1, message="event_ndims must be 1") + ndims = control_flow_ops.with_dependencies([is_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") + return array_ops.expand_dims(output_shape[0] - 1, dim=0) + + def _forward(self, x): + # Pad the last dim with a zeros vector. We need this because it lets us + # infer the scale in the inverse function. + y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x + y = distribution_util.pad(y, axis=-1, back=True) + + # Set shape hints. + if x.shape.ndims is not None: + shape = x.shape.as_list() + if self._static_event_ndims == 0: + shape += [2] + elif shape[-1] is not None: + shape[-1] += 1 + shape = tensor_shape.TensorShape(shape) + y.shape.assert_is_compatible_with(shape) + y.set_shape(shape) + + # Since we only support event_ndims in [0, 1] and we do padding, we always + # reduce over the last dimension, i.e., dim=-1 (which is the default). + return nn_ops.softmax(y) + + def _inverse(self, y): + # To derive the inverse mapping note that: + # y[i] = exp(x[i]) / normalization + # and + # y[end] = 1 / normalization. + # Thus: + # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) + # = log(exp(x[i])/normalization) - log(y[end]) + # = log(y[i]) - log(y[end]) + shape = (np.asarray(y.shape.as_list(), dtype=np.int32) + if y.shape.is_fully_defined() + else array_ops.shape(y, name="shape")) + ndims = distribution_util.prefer_static_rank(y) + + # Do this first to make sure CSE catches that it'll happen again in + # _inverse_log_det_jacobian. + x = math_ops.log(y) + + # We now extract the last coordinate of the rightmost dimension. + # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. + begin = array_ops.one_hot(indices=ndims-1, + depth=ndims, + on_value=shape[-1]-np.array(1, dtype=shape.dtype), + dtype=shape.dtype) + size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) + log_normalization = -array_ops.strided_slice(x, begin, begin + size) + + # Here we slice out all but the last coordinate; see above for idea. + begin = array_ops.zeros_like(shape) + size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) + x = array_ops.strided_slice(x, begin, begin + size) + + x += log_normalization + + if self._static_event_ndims == 0: + x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) + + # Set shape hints. + if y.shape.ndims is not None: + shape = y.shape.as_list() + if self._static_event_ndims == 0: + shape = shape[:-1] + elif shape[-1] is not None: + shape[-1] -= 1 + shape = tensor_shape.TensorShape(shape) + x.shape.assert_is_compatible_with(shape) + x.set_shape(shape) + + return x + + def _inverse_log_det_jacobian(self, y): + # WLOG, consider the vector case: + # x = log(y[:-1]) - log(y[-1]) + # where, + # y[-1] = 1 - sum(y[:-1]). + # We have: + # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } + # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) + # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } + # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * + # det(diag(y[:-1])) } (2) + # = 1 / { y[-1] prod(y[:-1]) } + # = 1 / prod(y) + # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula + # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector + # docstring "Tip". + # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma + return -math_ops.reduce_sum(math_ops.log(y), axis=-1) + + def _forward_log_det_jacobian(self, x): + if self._static_event_ndims == 0: + return x - 2. * nn_ops.softplus(x) + else: + # This code is similar to nn_ops.log_softmax but different because we have + # an implicit zero column to handle. I.e., instead of: + # reduce_sum(logits - reduce_sum(exp(logits), dim)) + # we must do: + # log_normalization = 1 + reduce_sum(exp(logits)) + # -log_normalization + reduce_sum(logits - log_normalization) + log_normalization = nn_ops.softplus( + math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + fldj = (-log_normalization + + math_ops.reduce_sum(x - log_normalization, + axis=-1, + keep_dims=True)) + return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py deleted file mode 100644 index 8645cc1b6b04be75a419342591272f07a4a1711c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SoftmaxCentered bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "SoftmaxCentered", -] - - -class SoftmaxCentered(bijector.Bijector): - """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. - - To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a - bijection, the forward transformation appends a value to the input and the - inverse removes this coordinate. The appended coordinate represents a pivot, - e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last - coordinate. - - Because we append a coordinate, this bijector only supports `event_ndim in [0, - 1]`, i.e., scalars and vectors. - - Example Use: - - ```python - bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) - # Result: [0.2, 0.3, 0.4, 0.1] - # Extra result: 0.1 - - bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) - # Result: tf.log([2, 3, 4]) - # Extra coordinate removed. - ``` - - At first blush it may seem like the [Invariance of domain]( - https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this - implementation is not a bijection. However, the appended dimension - makes the (forward) image non-open and the theorem does not directly apply. - """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="softmax_centered"): - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 1]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") - self._static_event_ndims = event_ndims - super(SoftmaxCentered, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward_event_shape(self, input_shape): - if input_shape.ndims is None: - return input_shape - if input_shape.ndims != self._static_event_ndims: - raise ValueError("input_shape.dims = %d != %d" % - (input_shape.ndims, self._static_event_ndims)) - if input_shape.ndims == 0: - return tensor_shape.TensorShape([2]) - if input_shape.ndims == 1: - return tensor_shape.TensorShape(input_shape[0] + 1) - # Unreachable code: - raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) - - def _forward_event_shape_tensor(self, input_shape): - ndims = array_ops.shape(input_shape) - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_zero_or_one = check_ops.assert_equal( - ndims, 0 if self._static_event_ndims == 0 else 1, - message="event_ndims must be 0 or 1") - ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor( - [2], dtype=dtypes.int32, name="output_shape") - return input_shape + 1 - - def _inverse_event_shape(self, output_shape): - if output_shape.ndims is None: - return output_shape - if output_shape.ndims != 1: - raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) - if self._static_event_ndims == 0: - return tensor_shape.TensorShape([]) - return tensor_shape.TensorShape(output_shape[0] - 1) - - def _inverse_event_shape_tensor(self, output_shape): - ndims = array_ops.shape(output_shape)[0] - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_one = check_ops.assert_equal( - ndims, 1, message="event_ndims must be 1") - ndims = control_flow_ops.with_dependencies([is_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") - return array_ops.expand_dims(output_shape[0] - 1, dim=0) - - def _forward(self, x): - # Pad the last dim with a zeros vector. We need this because it lets us - # infer the scale in the inverse function. - y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x - ndims = (y.get_shape().ndims if y.get_shape().ndims is not None - else array_ops.rank(y)) - y = array_ops.pad(y, - paddings=array_ops.concat( - (array_ops.zeros( - (ndims - 1, 2), dtype=dtypes.int32), [[0, 1]]), - 0)) - - # Set shape hints. - if x.get_shape().ndims is not None: - shape = x.get_shape().as_list() - if self._static_event_ndims == 0: - shape += [2] - elif shape[-1] is not None: - shape[-1] += 1 - shape = tensor_shape.TensorShape(shape) - y.get_shape().assert_is_compatible_with(shape) - y.set_shape(shape) - - # Since we only support event_ndims in [0, 1] and we do padding, we always - # reduce over the last dimension, i.e., dim=-1 (which is the default). - return nn_ops.softmax(y) - - def _inverse(self, y): - # To derive the inverse mapping note that: - # y[i] = exp(x[i]) / normalization - # and - # y[end] = 1 / normalization. - # Thus: - # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) - # = log(exp(x[i])/normalization) - log(y[end]) - # = log(y[i]) - log(y[end]) - shape = (np.asarray(y.get_shape().as_list(), dtype=np.int32) - if y.get_shape().is_fully_defined() - else array_ops.shape(y, name="shape")) - ndims = y.get_shape().ndims or math_ops.rank(y, name="ndims") - - # Do this first to make sure CSE catches that it'll happen again in - # _inverse_log_det_jacobian. - x = math_ops.log(y) - - # We now extract the last coordinate of the rightmost dimension. - # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. - begin = array_ops.one_hot(indices=ndims-1, - depth=ndims, - on_value=shape[-1]-np.array(1, dtype=shape.dtype), - dtype=shape.dtype) - size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) - log_normalization = -array_ops.strided_slice(x, begin, begin + size) - - # Here we slice out all but the last coordinate; see above for idea. - begin = array_ops.zeros_like(shape) - size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) - x = array_ops.strided_slice(x, begin, begin + size) - - x += log_normalization - - if self._static_event_ndims == 0: - x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) - - # Set shape hints. - if y.get_shape().ndims is not None: - shape = y.get_shape().as_list() - if self._static_event_ndims == 0: - shape = shape[:-1] - elif shape[-1] is not None: - shape[-1] -= 1 - shape = tensor_shape.TensorShape(shape) - x.get_shape().assert_is_compatible_with(shape) - x.set_shape(shape) - - return x - - def _inverse_log_det_jacobian(self, y): - # WLOG, consider the vector case: - # x = log(y[:-1]) - log(y[-1]) - # where, - # y[-1] = 1 - sum(y[:-1]). - # We have: - # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } - # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) - # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } - # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * - # det(diag(y[:-1])) } (2) - # = 1 / { y[-1] prod(y[:-1]) } - # = 1 / prod(y) - # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula - # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector - # docstring "Tip". - # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma - return -math_ops.reduce_sum(math_ops.log(y), axis=-1) - - def _forward_log_det_jacobian(self, x): - if self._static_event_ndims == 0: - return x - 2. * nn_ops.softplus(x) - else: - # This code is similar to nn_ops.log_softmax but different because we have - # an implicit zero column to handle. I.e., instead of: - # reduce_sum(logits - reduce_sum(exp(logits), dim)) - # we must do: - # log_normalization = 1 + reduce_sum(exp(logits)) - # -log_normalization + reduce_sum(logits - log_normalization) - log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 250a1144b53bb43271ff7ee494604d9bae6feda8..81957fcf78922fa15fd20a25d144071f431161ae 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -18,12 +18,127 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softplus_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["Softplus"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Softplus", +] + + +class Softplus(bijector.Bijector): + """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. + + The softplus `Bijector` has the following two useful properties: + + * The domain is the positive real numbers + * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as + the `Exp` `Bijector`. + + The optional nonzero `hinge_softness` parameter changes the transition at + zero. With `hinge_softness = c`, the bijector is: + + ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` + + For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, + so the behavior for large `x` is the same as the standard softplus. + + As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, + approaching `max(0, x)`. + + * `c = 1` is the default. + * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. + * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. + * `c = 0` results in a non-bijective transformation and triggers an exception. + + Example Use: + + ```python + # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + softplus = Softplus(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + log(1 + exp(x)) == softplus.forward(x) + log(exp(x) - 1) == softplus.inverse(x) + ``` + + Note: log(.) and exp(.) are applied element-wise but the Jacobian is a + reduction over the event space. + """ + + @distribution_util.AppendDocstring( + kwargs_dict={ + "hinge_softness": ( + "Nonzero floating point `Tensor`. Controls the softness of what " + "would otherwise be a kink at the origin. Default is 1.0")}) + def __init__(self, + event_ndims=0, + hinge_softness=None, + validate_args=False, + name="softplus"): + with ops.name_scope(name, values=[hinge_softness]): + if hinge_softness is not None: + self._hinge_softness = ops.convert_to_tensor( + hinge_softness, name="hinge_softness") + else: + self._hinge_softness = None + if validate_args: + nonzero_check = check_ops.assert_none_equal( + ops.convert_to_tensor( + 0, dtype=self.hinge_softness.dtype), + self.hinge_softness, + message="hinge_softness must be non-zero") + self._hinge_softness = control_flow_ops.with_dependencies( + [nonzero_check], self.hinge_softness) + + super(Softplus, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self.hinge_softness is None: + return nn_ops.softplus(x) + hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) + return hinge_softness * nn_ops.softplus(x / hinge_softness) + + def _inverse(self, y): + if self.hinge_softness is None: + return distribution_util.softplus_inverse(y) + hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) + return hinge_softness * distribution_util.softplus_inverse( + y / hinge_softness) + + def _inverse_log_det_jacobian(self, y): + # Could also do: + # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), + # axis=event_dims) + # but the following is more numerically stable. Ie, + # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] + # ==> dX/dY = exp{Y} / (exp{Y} - 1) + # = 1 / (1 - exp{-Y}), + # which is the most stable for large Y > 0. For small Y, we use + # 1 - exp{-Y} approx Y. + if self.hinge_softness is not None: + y /= math_ops.cast(self.hinge_softness, y.dtype) + return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), + axis=self._event_dims_tensor(y)) + + def _forward_log_det_jacobian(self, x): + if self.hinge_softness is not None: + x /= math_ops.cast(self.hinge_softness, x.dtype) + return -math_ops.reduce_sum(nn_ops.softplus(-x), + axis=self._event_dims_tensor(x)) + + @property + def hinge_softness(self): + return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py deleted file mode 100644 index 81957fcf78922fa15fd20a25d144071f431161ae..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Softplus bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "Softplus", -] - - -class Softplus(bijector.Bijector): - """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. - - The softplus `Bijector` has the following two useful properties: - - * The domain is the positive real numbers - * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as - the `Exp` `Bijector`. - - The optional nonzero `hinge_softness` parameter changes the transition at - zero. With `hinge_softness = c`, the bijector is: - - ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` - - For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, - so the behavior for large `x` is the same as the standard softplus. - - As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, - approaching `max(0, x)`. - - * `c = 1` is the default. - * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. - * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. - * `c = 0` results in a non-bijective transformation and triggers an exception. - - Example Use: - - ```python - # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - softplus = Softplus(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - log(1 + exp(x)) == softplus.forward(x) - log(exp(x) - 1) == softplus.inverse(x) - ``` - - Note: log(.) and exp(.) are applied element-wise but the Jacobian is a - reduction over the event space. - """ - - @distribution_util.AppendDocstring( - kwargs_dict={ - "hinge_softness": ( - "Nonzero floating point `Tensor`. Controls the softness of what " - "would otherwise be a kink at the origin. Default is 1.0")}) - def __init__(self, - event_ndims=0, - hinge_softness=None, - validate_args=False, - name="softplus"): - with ops.name_scope(name, values=[hinge_softness]): - if hinge_softness is not None: - self._hinge_softness = ops.convert_to_tensor( - hinge_softness, name="hinge_softness") - else: - self._hinge_softness = None - if validate_args: - nonzero_check = check_ops.assert_none_equal( - ops.convert_to_tensor( - 0, dtype=self.hinge_softness.dtype), - self.hinge_softness, - message="hinge_softness must be non-zero") - self._hinge_softness = control_flow_ops.with_dependencies( - [nonzero_check], self.hinge_softness) - - super(Softplus, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self.hinge_softness is None: - return nn_ops.softplus(x) - hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) - return hinge_softness * nn_ops.softplus(x / hinge_softness) - - def _inverse(self, y): - if self.hinge_softness is None: - return distribution_util.softplus_inverse(y) - hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) - return hinge_softness * distribution_util.softplus_inverse( - y / hinge_softness) - - def _inverse_log_det_jacobian(self, y): - # Could also do: - # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), - # axis=event_dims) - # but the following is more numerically stable. Ie, - # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] - # ==> dX/dY = exp{Y} / (exp{Y} - 1) - # = 1 / (1 - exp{-Y}), - # which is the most stable for large Y > 0. For small Y, we use - # 1 - exp{-Y} approx Y. - if self.hinge_softness is not None: - y /= math_ops.cast(self.hinge_softness, y.dtype) - return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), - axis=self._event_dims_tensor(y)) - - def _forward_log_det_jacobian(self, x): - if self.hinge_softness is not None: - x /= math_ops.cast(self.hinge_softness, x.dtype) - return -math_ops.reduce_sum(nn_ops.softplus(-x), - axis=self._event_dims_tensor(x)) - - @property - def hinge_softness(self): - return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index d439f28884d8bd7f2b808317e10c5b5e44bfcfa2..00520bcda85e9527767e6342bf75f10667c264a8 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -18,12 +18,132 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.weibull_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Weibull"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Weibull", +] + + +class Weibull(bijector.Bijector): + """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. + + This bijector maps inputs from `[0, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): + + ```none + Y ~ Weibull(scale, concentration) + pdf(y; scale, concentration, y >= 0) = (scale / concentration) * ( + scale / concentration) ** (concentration - 1) * exp( + -(y / scale) ** concentration) + ``` + """ + + def __init__(self, + scale=1., + concentration=1., + event_ndims=0, + validate_args=False, + name="weibull"): + """Instantiates the `Weibull` bijector. + + Args: + scale: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `concentration`. + This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + concentration: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[scale, concentration]): + self._scale = ops.convert_to_tensor(scale, name="scale") + self._concentration = ops.convert_to_tensor( + concentration, name="concentration") + check_ops.assert_same_float_dtype([self._scale, self._concentration]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, + message="Argument scale was not positive") + ], self._scale) + self._concentration = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._concentration, + message="Argument concentration was not positive") + ], self._concentration) + + super(Weibull, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def scale(self): + """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._scale + + @property + def concentration(self): + """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._concentration + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + return -math_ops.expm1(-((x / self.scale) ** self.concentration)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + -math_ops.log1p(-y) + + (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + + math_ops.log(self.scale / self.concentration), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + -(x / self.scale) ** self.concentration + + (self.concentration - 1) * math_ops.log(x) + + math_ops.log(self.concentration) + + -self.concentration * math_ops.log(self.scale), + axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args: + return x + is_valid = check_ops.assert_non_negative( + x, + message="Forward transformation input must be at least {}.".format(0)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py deleted file mode 100644 index 00520bcda85e9527767e6342bf75f10667c264a8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Weibull bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Weibull", -] - - -class Weibull(bijector.Bijector): - """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. - - This bijector maps inputs from `[0, inf]` to [0, 1]`. The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): - - ```none - Y ~ Weibull(scale, concentration) - pdf(y; scale, concentration, y >= 0) = (scale / concentration) * ( - scale / concentration) ** (concentration - 1) * exp( - -(y / scale) ** concentration) - ``` - """ - - def __init__(self, - scale=1., - concentration=1., - event_ndims=0, - validate_args=False, - name="weibull"): - """Instantiates the `Weibull` bijector. - - Args: - scale: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `concentration`. - This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - concentration: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[scale, concentration]): - self._scale = ops.convert_to_tensor(scale, name="scale") - self._concentration = ops.convert_to_tensor( - concentration, name="concentration") - check_ops.assert_same_float_dtype([self._scale, self._concentration]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, - message="Argument scale was not positive") - ], self._scale) - self._concentration = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._concentration, - message="Argument concentration was not positive") - ], self._concentration) - - super(Weibull, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def scale(self): - """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._scale - - @property - def concentration(self): - """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._concentration - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - return -math_ops.expm1(-((x / self.scale) ** self.concentration)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - -math_ops.log1p(-y) + - (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + - math_ops.log(self.scale / self.concentration), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - -(x / self.scale) ** self.concentration + - (self.concentration - 1) * math_ops.log(x) + - math_ops.log(self.concentration) + - -self.concentration * math_ops.log(self.scale), - axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args: - return x - is_valid = check_ops.assert_non_negative( - x, - message="Forward transformation input must be at least {}.".format(0)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index 8d59c1abfbc607c67b2bbca21f880743a43e5b2a..6f5d724a2a945ed8f9c159d8314327c6f994d1db 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -43,16 +43,17 @@ class Cauchy(distribution.Distribution): The probability density function (pdf) is, ```none - pdf(x; loc, scale) = 1 / (pi * scale * (1 + ((x - loc) / scale)**2)) + pdf(x; loc, scale) = 1 / (pi scale (1 + z**2)) + z = (x - loc) / scale ``` where `loc` is the location, and `scale` is the scale. The Cauchy distribution is a member of the [location-scale family]( https://en.wikipedia.org/wiki/Location-scale_family), i.e. + `Y ~ Cauchy(loc, scale)` is equivalent to, ```none X ~ Cauchy(loc=0, scale=1) - Y ~ Cauchy(loc=loc, scale=scale) Y = loc + scale * X ``` @@ -61,14 +62,16 @@ class Cauchy(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Cauchy distribution. - dist = Cauchy(loc=0., scale=3.) + dist = tfd.Cauchy(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Cauchy distributions. - dist = Cauchy(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Cauchy(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -76,18 +79,17 @@ class Cauchy(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - - Arguments are broadcast when possible. - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Cauchy distributions. # Both have median 1, but different scales. - dist = tf.contrib.distributions.Cauchy(loc=1., scale=[11, 22.]) + dist = tfd.Cauchy(loc=1., scale=[11, 22.]) + # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. - dist.prob(3.0) + dist.prob(3.) ``` + """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 599c855cda434d9249187d5d154d50a8a8c49a6c..1d4c5660d8d73b7b6a7e758fc834ccfddeb5c8ea 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -121,7 +121,7 @@ class ConditionalTransformedDistribution( log_prob = self.distribution.log_prob(x, **distribution_kwargs) if self._is_maybe_event_override: log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices) - return ildj + log_prob + return math_ops.cast(ildj, log_prob.dtype) + log_prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _prob(self, y, bijector_kwargs=None, distribution_kwargs=None): @@ -143,7 +143,7 @@ class ConditionalTransformedDistribution( prob = self.distribution.prob(x, **distribution_kwargs) if self._is_maybe_event_override: prob = math_ops.reduce_prod(prob, self._reduce_event_indices) - return math_ops.exp(ildj) * prob + return math_ops.exp(math_ops.cast(ildj, prob.dtype)) * prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _log_cdf(self, y, bijector_kwargs=None, distribution_kwargs=None): diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 850d08d1bd69ebc7661557d648e2bffe77e6a908..8049522e9f5dc26b244b7e710a9ae8b981efd6b6 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -290,8 +290,10 @@ class VectorDeterministic(_BaseDeterministic): #### Examples ```python + tfd = tf.contrib.distributions + # Initialize a single VectorDeterministic supported at [0., 2.] in R^2. - constant = tf.contrib.distributions.Deterministic([0., 2.]) + constant = tfd.Deterministic([0., 2.]) constant.prob([0., 2.]) ==> 1. constant.prob([0., 3.]) @@ -299,7 +301,7 @@ class VectorDeterministic(_BaseDeterministic): # Initialize a [3] batch of constants on R^2. loc = [[0., 1.], [2., 3.], [4., 5.]] - constant = constant_lib.VectorDeterministic(loc) + constant = tfd.VectorDeterministic(loc) constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]]) ==> [1., 0., 0.] ``` diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 869b5698e57d199755ce1686a74a1eafe3b73e7d..a4d249d41ec9733721a3583d3708e0da56db1733 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -19,9 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import linalg -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -330,54 +328,14 @@ def shapes_from_loc_and_scale(loc, scale, name="shapes_from_loc_and_scale"): else: loc_batch_shape = ops.convert_to_tensor(loc_batch_shape, name="loc_batch_shape") + # This is defined in the core util module. + # pylint: disable=undefined-variable batch_shape = prefer_static_broadcast_shape(batch_shape, loc_batch_shape) + # pylint: enable=undefined-variable return batch_shape, event_shape -def prefer_static_broadcast_shape( - shape1, shape2, name="prefer_static_broadcast_shape"): - """Convenience function which statically broadcasts shape when possible. - - Args: - shape1: `1-D` integer `Tensor`. Already converted to tensor! - shape2: `1-D` integer `Tensor`. Already converted to tensor! - name: A string name to prepend to created ops. - - Returns: - The broadcast shape, either as `TensorShape` (if broadcast can be done - statically), or as a `Tensor`. - """ - with ops.name_scope(name, values=[shape1, shape2]): - def make_shape_tensor(x): - return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32) - - def get_tensor_shape(s): - if isinstance(s, tensor_shape.TensorShape): - return s - s_ = tensor_util.constant_value(make_shape_tensor(s)) - if s_ is not None: - return tensor_shape.TensorShape(s_) - return None - - def get_shape_tensor(s): - if not isinstance(s, tensor_shape.TensorShape): - return make_shape_tensor(s) - if s.is_fully_defined(): - return make_shape_tensor(s.as_list()) - raise ValueError("Cannot broadcast from partially " - "defined `TensorShape`.") - - shape1_ = get_tensor_shape(shape1) - shape2_ = get_tensor_shape(shape2) - if shape1_ is not None and shape2_ is not None: - return array_ops.broadcast_static_shape(shape1_, shape2_) - - shape1_ = get_shape_tensor(shape1) - shape2_ = get_shape_tensor(shape2) - return array_ops.broadcast_dynamic_shape(shape1_, shape2_) - - def get_broadcast_shape(*tensors): """Get broadcast shape as a Python list of integers (preferred) or `Tensor`. diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index ba8d3c639b397422f0f6210ba9f48650f0da1e3e..d0efaefb8e78ddf4436e9e5a112d2c1cdddaf3b5 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -62,15 +62,17 @@ class _Gumbel(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Gumbel distribution. - dist = tf.contrib.distributions.Gumbel(loc=0., scale=3.) + dist = tfd.Gumbel(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Gumbels. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Gumbel(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Gumbel(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -85,7 +87,7 @@ class _Gumbel(distribution.Distribution): ```python # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Gumbel(loc=1., scale=[11, 22.]) + dist = tfd.Gumbel(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index 12059b6a9e199dc3ae00ac47a62ece9c9a147000..fc0751a6e0b78cb3d79bd3478e740bb05cd26428 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -84,6 +84,7 @@ class HalfNormal(distribution.Distribution): ``` """ + def __init__(self, scale, validate_args=False, @@ -120,7 +121,7 @@ class HalfNormal(distribution.Distribution): @staticmethod def _param_shapes(sample_shape): - return {'scale': ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} + return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} @property def scale(self): diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 6a74ca9a0ae1ad30081d21cc15a65be052a99e2a..cbce005013281ff3c58c94d525d5ce7a865d725a 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -68,11 +68,11 @@ class Independent(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Make independent distribution from a 2-batch Normal. - ind = ds.Independent( - distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]), + ind = tfd.Independent( + distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]), reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. @@ -80,8 +80,8 @@ class Independent(distribution_lib.Distribution): ind.event_shape # ==> [2] # Make independent distribution from a 2-batch bivariate Normal. - ind = ds.Independent( - distribution=ds.MultivariateNormalDiag( + ind = tfd.Independent( + distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]), reinterpreted_batch_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 956dee38a378813434656a28a69c89b6ec1e8b72..ee4d86867d48b20e97757bcec57d452085814b80 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -88,8 +88,9 @@ class InverseGamma(distribution.Distribution): #### Examples ```python - dist = InverseGamma(concentration=3.0, rate=2.0) - dist2 = InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + tfd = tf.contrib.distributions + dist = tfd.InverseGamma(concentration=3.0, rate=2.0) + dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) ``` """ diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 48794a48828fe796e233e968d8c755136ce166ad..473677f8d91b184e029f345bb05f5c5d63df7a40 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -60,15 +60,17 @@ class Logistic(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Logistic distribution. - dist = tf.contrib.distributions.Logistic(loc=0., scale=3.) + dist = tfd.Logistic(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Logistics. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Logistic(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Logistic(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -76,14 +78,11 @@ class Logistic(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - Arguments are broadcast when possible. - - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Logistic(loc=1., scale=[11, 22.]) + dist = tfd.Logistic(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index e676931d9145e72907d990148ee2d180e0da0258..f2d492f5489a197157558ae727416b51db04793e 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -49,13 +49,13 @@ class Mixture(distribution.Distribution): ```python # Create a mixture of two Gaussians: - ds = tf.contrib.distributions + tfd = tf.contrib.distributions mix = 0.3 - bimix_gauss = ds.Mixture( - cat=ds.Categorical(probs=[mix, 1.-mix]), + bimix_gauss = tfd.Mixture( + cat=tfd.Categorical(probs=[mix, 1.-mix]), components=[ - ds.Normal(loc=-1., scale=0.1), - ds.Normal(loc=+1., scale=0.5), + tfd.Normal(loc=-1., scale=0.1), + tfd.Normal(loc=+1., scale=0.5), ]) # Plot the PDF. diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 5558ef0f255db684b229d129666634e50c625887..49afbea7f05136674aa0c1441bd46548b7b55c8f 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -43,15 +43,14 @@ class MixtureSameFamily(distribution.Distribution): #### Examples ```python - import matplotlib.pyplot as plt - ds = tf.contrib.distributions + tfd = tf.contrib.distributions ### Create a mixture of two scalar Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.Normal( + components_distribution=tfd.Normal( loc=[-1., 1], # One for each component. scale=[0.1, 0.5])) # And same here. @@ -63,14 +62,15 @@ class MixtureSameFamily(distribution.Distribution): # Plot PDF. x = np.linspace(-2., 3., int(1e4), dtype=np.float32) + import matplotlib.pyplot as plt plt.plot(x, gm.prob(x).eval()); ### Create a mixture of two Bivariate Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.MultivariateNormalDiag( + components_distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], # component 1 [1, -1]], # component 2 scale_identity_multiplier=[.3, .6])) @@ -248,7 +248,7 @@ class MixtureSameFamily(distribution.Distribution): x = self._pad_sample_dims(x) log_prob_x = self.components_distribution.log_prob(x) # [S, B, k] log_mix_prob = nn_ops.log_softmax( - self.mixture_distribution.logits, dim=-1) # [B, k] + self.mixture_distribution.logits, axis=-1) # [B, k] return math_ops.reduce_logsumexp( log_prob_x + log_mix_prob, axis=-1) # [S, B] @@ -264,7 +264,7 @@ class MixtureSameFamily(distribution.Distribution): x = self._pad_sample_dims(x) log_cdf_x = self.components_distribution.log_cdf(x) # [S, B, k] log_mix_prob = nn_ops.log_softmax( - self.mixture_distribution.logits, dim=-1) # [B, k] + self.mixture_distribution.logits, axis=-1) # [B, k] return math_ops.reduce_logsumexp( log_cdf_x + log_mix_prob, axis=-1) # [S, B] @@ -320,13 +320,14 @@ class MixtureSameFamily(distribution.Distribution): return array_ops.shape(d.batch_shape_tensor())[0] dist_batch_ndims = _get_ndims(self) cat_batch_ndims = _get_ndims(self.mixture_distribution) - bnd = distribution_util.pick_vector( + pad_ndims = array_ops.where( self.mixture_distribution.is_scalar_batch(), - [dist_batch_ndims], [cat_batch_ndims])[0] + dist_batch_ndims, + dist_batch_ndims - cat_batch_ndims) s = array_ops.shape(x) x = array_ops.reshape(x, shape=array_ops.concat([ s[:-1], - array_ops.ones([bnd], dtype=dtypes.int32), + array_ops.ones([pad_ndims], dtype=dtypes.int32), s[-1:], array_ops.ones([self._event_ndims], dtype=dtypes.int32), ], axis=0)) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index 163cf75d990d5fe7ec1e3aaf0040fc71f61774a7..e862552880f4073c8fa8e90134d0633e7484b0bf 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -84,10 +84,10 @@ class MultivariateNormalDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -101,7 +101,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -119,7 +119,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate Gaussians. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 040bc230722194316b8a74627344e315a2578281..413e88f03ae0286c294f3404549a73e1a47dcff7 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -86,7 +86,7 @@ class MultivariateNormalDiagPlusLowRank( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`, # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is @@ -97,7 +97,7 @@ class MultivariateNormalDiagPlusLowRank( [-1, 1], [2, -0.5]] # shape: [3, 2] m = [4., 5] # shape: [2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu scale_diag=d scale_perturb_factor=U, @@ -118,7 +118,7 @@ class MultivariateNormalDiagPlusLowRank( m = [[0.1, 0.2], [0.4, 0.5]] # shape: [b, r] = [2, 2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu, scale_perturb_factor=U, scale_perturb_diag=m) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index f9952b2069d6dfd2593e6bd71ede0badf44cdf98..4bea99fbb75349f97fde473cb5716fe6c426ce90 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -73,14 +73,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] cov = [[ 0.36, 0.12, 0.06], [ 0.12, 0.29, -0.13], [ 0.06, -0.13, 0.26]] - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance_matrix=cov) @@ -100,7 +100,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] covariance_matrix = ... # shape: [2, 3, 3], symmetric, positive definite. - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance=covariance_matrix) @@ -167,12 +167,11 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = ops.convert_to_tensor( covariance_matrix, name="covariance_matrix") if validate_args: - assert_symmetric = check_ops.assert_equal( - covariance_matrix, - array_ops.matrix_transpose(covariance_matrix), - message="Matrix was not symmetric.") - covariance_matrix = control_flow_ops.with_dependencies( - [assert_symmetric], covariance_matrix) + covariance_matrix = control_flow_ops.with_dependencies([ + check_ops.assert_near( + covariance_matrix, + array_ops.matrix_transpose(covariance_matrix), + message="Matrix was not symmetric")], covariance_matrix) # No need to validate that covariance_matrix is non-singular. # LinearOperatorLowerTriangular has an assert_non_singular method that # is called by the Bijector. diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 300bdd5f6064a1cc9c336689ac4fae04338edb30..a7399792892f4c179c05168184d76ec95c168b51 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -90,8 +90,7 @@ class MultivariateNormalLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -103,9 +102,9 @@ class MultivariateNormalLinearOperator( # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale)) + scale=tf.linalg.LinearOperatorLowerTriangular(scale)) # Covariance agrees with cholesky(cov) parameterization. mvn.covariance().eval() @@ -122,9 +121,9 @@ class MultivariateNormalLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 260dcc18f513d5440d3d39368539274c03faa72a..6c7dc4ca7aaf5b3a20b072e9360d15528ad10556 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -76,12 +76,13 @@ class MultivariateNormalTriL( ``` Trainable (batch) lower-triangular matrices can be created with - `ds.matrix_diag_transform()` and/or `ds.fill_triangular()` + `tf.contrib.distributions.matrix_diag_transform()` and/or + `tf.contrib.distributions.fill_triangular()` #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -92,7 +93,7 @@ class MultivariateNormalTriL( # ==> [[ 0.6, 0. , 0. ], # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=scale) @@ -112,7 +113,7 @@ class MultivariateNormalTriL( mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal. - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=tril) @@ -124,9 +125,9 @@ class MultivariateNormalTriL( # Instantiate a "learnable" MVN. dims = 4 with tf.variable_scope("model"): - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"), - scale_tril=ds.fill_triangular( + scale_tril=tfd.fill_triangular( tf.get_variable(shape=[dims * (dims + 1) / 2], dtype=tf.float32, name="chol_Sigma"))) ``` diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index e1118ed4312ca2ed678a05a298110e2669d0a27e..92f2bba1828696248c9d9460566a08ba372c3358 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -22,21 +22,135 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib +from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_lib __all__ = [ "PoissonLogNormalQuadratureCompound", + "quadrature_scheme_lognormal_gauss_hermite", + "quadrature_scheme_lognormal_quantiles", ] +def quadrature_scheme_lognormal_gauss_hermite( + loc, scale, quadrature_size, + validate_args=False, name=None): # pylint: disable=unused-argument + """Use Gauss-Hermite quadrature to form quadrature on positive-reals. + + Note: for a given `quadrature_size`, this method is generally less accurate + than `quadrature_scheme_lognormal_quantiles`. + + Args: + loc: `float`-like (batch of) scalar `Tensor`; the location parameter of + the LogNormal prior. + scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of + the LogNormal prior. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: (Batch of) length-`quadrature_size` vectors representing the + `log_rate` parameters of a `Poisson`. + probs: (Batch of) length-`quadrature_size` vectors representing the + weight associate with each `grid` value. + """ + with ops.name_scope(name, "vector_diffeomixture_quadrature_gauss_hermite", + [loc, scale]): + grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size) + grid = grid.astype(loc.dtype.as_numpy_dtype) + probs = probs.astype(loc.dtype.as_numpy_dtype) + probs /= np.linalg.norm(probs, ord=1, keepdims=True) + probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype) + # The following maps the broadcast of `loc` and `scale` to each grid + # point, i.e., we are creating several log-rates that correspond to the + # different Gauss-Hermite quadrature points and (possible) batches of + # `loc` and `scale`. + grid = (loc[..., array_ops.newaxis] + + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid) + return grid, probs + + +def quadrature_scheme_lognormal_quantiles( + loc, scale, quadrature_size, + validate_args=False, name=None): + """Use LogNormal quantiles to form quadrature on positive-reals. + + Args: + loc: `float`-like (batch of) scalar `Tensor`; the location parameter of + the LogNormal prior. + scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of + the LogNormal prior. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: (Batch of) length-`quadrature_size` vectors representing the + `log_rate` parameters of a `Poisson`. + probs: (Batch of) length-`quadrature_size` vectors representing the + weight associate with each `grid` value. + """ + with ops.name_scope(name, "quadrature_scheme_lognormal_quantiles", + [loc, scale]): + # Create a LogNormal distribution. + dist = transformed_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=loc, scale=scale), + bijector=Exp(event_ndims=0), + validate_args=validate_args) + batch_ndims = dist.batch_shape.ndims + if batch_ndims is None: + batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] + + def _compute_quantiles(): + """Helper to build quantiles.""" + # Omit {0, 1} since they might lead to Inf/NaN. + zero = array_ops.zeros([], dtype=dist.dtype) + edges = math_ops.linspace(zero, 1., quadrature_size + 3)[1:-1] + # Expand edges so its broadcast across batch dims. + edges = array_ops.reshape(edges, shape=array_ops.concat([ + [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) + quantiles = dist.quantile(edges) + # Cyclically permute left by one. + perm = array_ops.concat([ + math_ops.range(1, 1 + batch_ndims), [0]], axis=0) + quantiles = array_ops.transpose(quantiles, perm) + return quantiles + quantiles = _compute_quantiles() + + # Compute grid as quantile midpoints. + grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2. + # Set shape hints. + grid.set_shape(dist.batch_shape.concatenate([quadrature_size])) + + # By construction probs is constant, i.e., `1 / quadrature_size`. This is + # important, because non-constant probs leads to non-reparameterizable + # samples. + probs = array_ops.fill( + dims=[quadrature_size], + value=1. / math_ops.cast(quadrature_size, dist.dtype)) + + return grid, probs + + class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): """`PoissonLogNormalQuadratureCompound` distribution. @@ -47,30 +161,18 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): ```none p(k|loc, scale) = int_{R_+} dl LogNormal(l | loc, scale) Poisson(k | l) - = int_{R} dz ((lambda(z) sqrt(2) scale) - * exp(-z**2) / (lambda(z) sqrt(2 pi) sigma) - * Poisson(k | lambda(z))) - = int_{R} dz exp(-z**2) / sqrt(pi) Poisson(k | lambda(z)) approx= sum{ prob[d] Poisson(k | lambda(grid[d])) : d=0, ..., deg-1 } ``` - where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms - are from [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). Note that - the second line made the substitution: - `z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above] - and `dl = sqrt(2) scale lambda(z) dz` + By default, the `grid` is chosen as quantiles of the `LogNormal` distribution + parameterized by `loc`, `scale` and the `prob` vector is + `[1. / quadrature_size]*quadrature_size`. In the non-approximation case, a draw from the LogNormal prior represents the Poisson rate parameter. Unfortunately, the non-approximate distribution lacks an analytical probability density function (pdf). Therefore the `PoissonLogNormalQuadratureCompound` class implements an approximation based - on [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). + on [quadrature](https://en.wikipedia.org/wiki/Numerical_integration). Note: although the `PoissonLogNormalQuadratureCompound` is approximately the Poisson-LogNormal compound distribution, it is itself a valid distribution. @@ -84,10 +186,8 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): https://en.wikipedia.org/wiki/Compound_probability_distribution). Using variable-substitution and [numerical quadrature]( https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can - redefine the distribution to be a parameter-less convex combination of `deg` - different Poisson samples. + based on `LogNormal` quantiles) we can redefine the distribution to be a + parameter-less convex combination of `deg` different Poisson samples. That is, defined over positive integers, this distribution is parameterized by a (batch of) `loc` and `scale` scalars. @@ -96,46 +196,51 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): ```none pdf(k | loc, scale, deg) - = sum{ prob[d] Poisson(k | lambda=exp(sqrt(2) scale grid[d] + loc)) + = sum{ prob[d] Poisson(k | lambda=exp(grid[d])) : d=0, ..., deg-1 } ``` - where, [e.g., `grid, w = numpy.polynomial.hermite.hermgauss(deg)`]( - https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html) - and `prob = w / sqrt(pi)`. - #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions + # Create two batches of PoissonLogNormalQuadratureCompounds, one with # prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.` - pln = ds.PoissonLogNormalQuadratureCompound( + pln = tfd.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + quadrature_size=10, validate_args=True) """ def __init__(self, loc, scale, - quadrature_grid_and_probs=None, + quadrature_size=8, + quadrature_fn=quadrature_scheme_lognormal_quantiles, validate_args=False, allow_nan_stats=True, name="PoissonLogNormalQuadratureCompound"): - """Constructs the PoissonLogNormalQuadratureCompound on `R**k`. + """Constructs the PoissonLogNormalQuadratureCompound`. + + Note: `probs` returned by (optional) `quadrature_fn` are presumed to be + either a length-`quadrature_size` vector or a batch of vectors in 1-to-1 + correspondence with the returned `grid`. (I.e., broadcasting is only + partially supported.) Args: loc: `float`-like (batch of) scalar `Tensor`; the location parameter of the LogNormal prior. scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of the LogNormal prior. - quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s - representing the sample points and the corresponding (possibly - normalized) weight. When `None`, defaults to: - `np.polynomial.hermite.hermgauss(deg=8)`. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + quadrature_fn: Python callable taking `loc`, `scale`, + `quadrature_size`, `validate_args` and returning `tuple(grid, probs)` + representing the LogNormal grid and corresponding normalized weight. + normalized) weight. + Default value: `quadrature_scheme_lognormal_quantiles`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -147,47 +252,41 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): name: Python `str` name prefixed to Ops created by this class. Raises: - TypeError: if `loc.dtype != scale[0].dtype`. + TypeError: if `quadrature_grid` and `quadrature_probs` have different base + `dtype`. """ parameters = locals() with ops.name_scope(name, values=[loc, scale]): - loc = ops.convert_to_tensor(loc, name="loc") - self._loc = loc + if loc is not None: + loc = ops.convert_to_tensor(loc, name="loc") + if scale is not None: + scale = ops.convert_to_tensor( + scale, dtype=None if loc is None else loc.dtype, name="scale") + self._quadrature_grid, self._quadrature_probs = tuple(quadrature_fn( + loc, scale, quadrature_size, validate_args)) + + dt = self._quadrature_grid.dtype + if dt.base_dtype != self._quadrature_probs.dtype.base_dtype: + raise TypeError("Quadrature grid dtype ({}) does not match quadrature " + "probs dtype ({}).".format( + dt.name, self._quadrature_probs.dtype.name)) - scale = ops.convert_to_tensor(scale, name="scale") - self._scale = scale - - dtype = loc.dtype.base_dtype - if dtype != scale.dtype.base_dtype: - raise TypeError( - "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format( - loc.dtype.name, scale.dtype.name)) - - grid, probs = distribution_util.process_quadrature_grid_and_probs( - quadrature_grid_and_probs, dtype, validate_args) - self._quadrature_grid = grid - self._quadrature_probs = probs - self._quadrature_size = distribution_util.dimension_size(probs, axis=0) + self._distribution = poisson_lib.Poisson( + log_rate=self._quadrature_grid, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats) self._mixture_distribution = categorical_lib.Categorical( logits=math_ops.log(self._quadrature_probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) - # The following maps the broadcast of `loc` and `scale` to each grid - # point, i.e., we are creating several log-rates that correspond to the - # different Gauss-Hermite quadrature points and (possible) batches of - # `loc` and `scale`. - self._log_rate = (loc[..., array_ops.newaxis] - + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid) - - self._distribution = poisson_lib.Poisson( - log_rate=self._log_rate, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats) + self._loc = loc + self._scale = scale + self._quadrature_size = quadrature_size super(PoissonLogNormalQuadratureCompound, self).__init__( - dtype=dtype, + dtype=dt, reparameterization_type=distribution_lib.NOT_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, @@ -197,12 +296,12 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): @property def mixture_distribution(self): - """Distribution which randomly selects a Poisson with Gauss-Hermite rate.""" + """Distribution which randomly selects a Poisson with quadrature param.""" return self._mixture_distribution @property def distribution(self): - """Base Poisson parameterized by a Gauss-Hermite grid of rates.""" + """Base Poisson parameterized by a quadrature grid.""" return self._distribution @property @@ -216,24 +315,18 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): return self._scale @property - def quadrature_grid(self): - """Quadrature grid points.""" - return self._quadrature_grid - - @property - def quadrature_probs(self): - """Quadrature normalized weights.""" - return self._quadrature_probs + def quadrature_size(self): + return self._quadrature_size def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( - array_ops.shape(self.loc), - array_ops.shape(self.scale)) + self.distribution.batch_shape_tensor(), + array_ops.shape(self.mixture_distribution.logits))[:-1] def _batch_shape(self): return array_ops.broadcast_static_shape( - self.loc.shape, - self.scale.shape) + self.distribution.batch_shape, + self.mixture_distribution.logits.shape)[:-1] def _event_shape(self): return tensor_shape.scalar() @@ -241,18 +334,31 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): def _sample_n(self, n, seed=None): # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get # ids as a [n]-shaped vector. - batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32) - if self.batch_shape.is_fully_defined() - else math_ops.reduce_prod(self.batch_shape_tensor())) + batch_size = self.batch_shape.num_elements() + if batch_size is None: + batch_size = math_ops.reduce_prod(self.batch_shape_tensor()) + # We need to "sample extra" from the mixture distribution if it doesn't + # already specify a probs vector for each batch coordinate. + # We only support this kind of reduced broadcasting, i.e., there is exactly + # one probs vector for all batch dims or one for each. ids = self._mixture_distribution.sample( sample_shape=concat_vectors( [n], distribution_util.pick_vector( - self.is_scalar_batch(), - np.int32([]), - [batch_size])), + self.mixture_distribution.is_scalar_batch(), + [batch_size], + np.int32([]))), seed=distribution_util.gen_new_seed( seed, "poisson_lognormal_quadrature_compound")) + # We need to flatten batch dims in case mixture_distribution has its own + # batch dims. + ids = array_ops.reshape(ids, shape=concat_vectors( + [n], + distribution_util.pick_vector( + self.is_scalar_batch(), + np.int32([]), + np.int32([-1])))) + # Stride `quadrature_size` for `batch_size` number of times. offset = math_ops.range(start=0, limit=batch_size * self._quadrature_size, @@ -275,7 +381,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): def _mean(self): return math_ops.exp( math_ops.reduce_logsumexp( - self.mixture_distribution.logits + self._log_rate, + self.mixture_distribution.logits + self.distribution.log_rate, axis=-1)) def _variance(self): @@ -300,7 +406,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): # Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 } v = array_ops.stack([ # log(self.distribution.variance()) = log(Var[d]) = log(rate[d]) - self._log_rate, + self.distribution.log_rate, # log((Mean[d] - Mean)**2) 2. * math_ops.log( math_ops.abs(self.distribution.mean() @@ -311,14 +417,9 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): axis=[-2, -1]) -def static_value(x): - """Returns the static value of a `Tensor` or `None`.""" - return tensor_util.constant_value(ops.convert_to_tensor(x)) - - def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" - args_ = [static_value(x) for x in args] + args_ = [distribution_util.static_value(x) for x in args] if any(vec is None for vec in args_): return array_ops.concat(args, axis=0) return [val for vec in args_ for val in vec] diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index 2a4b92c72900f79785e7e34b77179d3decbace5b..dfc813361977c159d8d48f9d5b9ff03db5b4acdc 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -28,12 +28,190 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import spectral_ops +from tensorflow.python.ops.distributions import util __all__ = [ + "auto_correlation", "percentile", ] +# TODO(langmore) Write separate versions of this for real/complex dtype, taking +# advantage of optimized real-fft ops. +def auto_correlation( + x, + axis=-1, + max_lags=None, + center=True, + normalize=True, + name="auto_correlation"): + """Auto correlation along one axis. + + Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation + `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate) + + ``` + RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) }, + W[n] := (X[n] - MU) / S, + MU := E{ X[0] }, + S**2 := E{ (X[0] - MU) Conj(X[0] - MU) }. + ``` + + This function takes the viewpoint that `x` is (along one axis) a finite + sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an + estimate of `RXX[m]` as follows: + + After extending `x` from length `L` to `inf` by zero padding, the auto + correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as + + ``` + rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]), + w[n] := (x[n] - mu) / s, + mu := L**-1 sum_n x[n], + s**2 := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu) + ``` + + The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users + often set `max_lags` small enough so that the entire output is meaningful. + + Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by + `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation + contains a slight bias, which goes to zero as `len(x) - m --> infinity`. + + Args: + x: `float32` or `complex64` `Tensor`. + axis: Python `int`. The axis number along which to compute correlation. + Other dimensions index different batch members. + max_lags: Positive `int` tensor. The maximum value of `m` to consider + (in equation above). If `max_lags >= x.shape[axis]`, we effectively + re-set `max_lags` to `x.shape[axis] - 1`. + center: Python `bool`. If `False`, do not subtract the mean estimate `mu` + from `x[n]` when forming `w[n]`. + normalize: Python `bool`. If `False`, do not divide by the variance + estimate `s**2` when forming `w[n]`. + name: `String` name to prepend to created ops. + + Returns: + `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for + `i != axis`, and `rxx.shape[axis] = max_lags + 1`. + + Raises: + TypeError: If `x` is not a supported type. + """ + # Implementation details: + # Extend length N / 2 1-D array x to length N by zero padding onto the end. + # Then, set + # F[x]_k := sum_n x_n exp{-i 2 pi k n / N }. + # It is not hard to see that + # F[x]_k Conj(F[x]_k) = F[R]_k, where + # R_m := sum_n x_n Conj(x_{(n - m) mod N}). + # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m]. + + # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT + # based version of estimating RXX. + # Note that this is a special case of the Wiener-Khinchin Theorem. + with ops.name_scope(name, values=[x]): + x = ops.convert_to_tensor(x, name="x") + + # Rotate dimensions of x in order to put axis at the rightmost dim. + # FFT op requires this. + rank = util.prefer_static_rank(x) + if axis < 0: + axis = rank + axis + shift = rank - 1 - axis + # Suppose x.shape[axis] = T, so there are T "time" steps. + # ==> x_rotated.shape = B + [T], + # where B is x_rotated's batch shape. + x_rotated = util.rotate_transpose(x, shift) + + if center: + x_rotated -= math_ops.reduce_mean(x_rotated, axis=-1, keepdims=True) + + # x_len = N / 2 from above explanation. The length of x along axis. + # Get a value for x_len that works in all cases. + x_len = util.prefer_static_shape(x_rotated)[-1] + + # TODO(langmore) Investigate whether this zero padding helps or hurts. At + # the moment is is necessary so that all FFT implementations work. + # Zero pad to the next power of 2 greater than 2 * x_len, which equals + # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2). + x_len_float64 = math_ops.cast(x_len, np.float64) + target_length = math_ops.pow( + np.float64(2.), + math_ops.ceil(math_ops.log(x_len_float64 * 2) / np.log(2.))) + pad_length = math_ops.cast(target_length - x_len_float64, np.int32) + + # We should have: + # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length] + # = B + [T + pad_length] + x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length) + + dtype = x.dtype + if not dtype.is_complex: + if not dtype.is_floating: + raise TypeError("Argument x must have either float or complex dtype" + " found: {}".format(dtype)) + x_rotated_pad = math_ops.complex(x_rotated_pad, + dtype.real_dtype.as_numpy_dtype(0.)) + + # Autocorrelation is IFFT of power-spectral density (up to some scaling). + fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad) + spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad) + # shifted_product is R[m] from above detailed explanation. + # It is the inner product sum_n X[n] * Conj(X[n - m]). + shifted_product = spectral_ops.ifft(spectral_density) + + # Cast back to real-valued if x was real to begin with. + shifted_product = math_ops.cast(shifted_product, dtype) + + # Figure out if we can deduce the final static shape, and set max_lags. + # Use x_rotated as a reference, because it has the time dimension in the far + # right, and was created before we performed all sorts of crazy shape + # manipulations. + know_static_shape = True + if not x_rotated.shape.is_fully_defined(): + know_static_shape = False + if max_lags is None: + max_lags = x_len - 1 + else: + max_lags = ops.convert_to_tensor(max_lags, name="max_lags") + max_lags_ = tensor_util.constant_value(max_lags) + if max_lags_ is None or not know_static_shape: + know_static_shape = False + max_lags = math_ops.minimum(x_len - 1, max_lags) + else: + max_lags = min(x_len - 1, max_lags_) + + # Chop off the padding. + # We allow users to provide a huge max_lags, but cut it off here. + # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags] + shifted_product_chopped = shifted_product[..., :max_lags + 1] + + # If possible, set shape. + if know_static_shape: + chopped_shape = x_rotated.shape.as_list() + chopped_shape[-1] = min(x_len, max_lags + 1) + shifted_product_chopped.set_shape(chopped_shape) + + # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The + # other terms were zeros arising only due to zero padding. + # `denominator = (N / 2 - m)` (defined below) is the proper term to + # divide by by to make this an unbiased estimate of the expectation + # E[X[n] Conj(X[n - m])]. + x_len = math_ops.cast(x_len, dtype.real_dtype) + max_lags = math_ops.cast(max_lags, dtype.real_dtype) + denominator = x_len - math_ops.range(0., max_lags + 1.) + denominator = math_ops.cast(denominator, dtype) + shifted_product_rotated = shifted_product_chopped / denominator + + if normalize: + shifted_product_rotated /= shifted_product_rotated[..., :1] + + # Transpose dimensions back to those of x. + return util.rotate_transpose(shifted_product_rotated, -shift) + + # TODO(langmore) To make equivalent to numpy.percentile: # Make work with a sequence of floats or single float for 'q'. # Make work with "linear", "midpoint" interpolation. (linear should be default) diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index b05f15771a3a94779ffddea8f16ad2fa4ea2fdd1..c4b8f055b7fbc3f0835b503eddd7617610326d8c 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -115,7 +115,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) distribution: `tf.Distribution`-like instance. Distribution that is transformed to produce this distribution. - Default is `ds.Normal(0., 1.)`. + Default is `tf.distributions.Normal(0., 1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py index 77f2a39273dc365a4ac202d846dd2bc364655c86..15b0820cbdf560e04a304c40a47e541006523b6d 100644 --- a/tensorflow/contrib/distributions/python/ops/test_util.py +++ b/tensorflow/contrib/distributions/python/ops/test_util.py @@ -40,6 +40,7 @@ class DiscreteScalarDistributionTestHelpers(object): def run_test_sample_consistent_log_prob( self, sess_run_fn, dist, num_samples=int(1e5), num_threshold=int(1e3), seed=42, + batch_size=None, rtol=1e-2, atol=0.): """Tests that sample/log_prob are consistent with each other. @@ -66,6 +67,8 @@ class DiscreteScalarDistributionTestHelpers(object): seed: Python `int` indicating the seed to use when sampling from `dist`. In general it is not recommended to use `None` during a test as this increases the likelihood of spurious test failure. + batch_size: Hint for unpacking result of samples. Default: `None` means + batch_size is inferred. rtol: Python `float`-type indicating the admissible relative error between analytical and sample statistics. atol: Python `float`-type indicating the admissible absolute error between @@ -80,10 +83,11 @@ class DiscreteScalarDistributionTestHelpers(object): # Histogram only supports vectors so we call it once per batch coordinate. y = dist.sample(num_samples, seed=seed) y = array_ops.reshape(y, shape=[num_samples, -1]) - batch_size = math_ops.reduce_prod(dist.batch_shape_tensor()) + if batch_size is None: + batch_size = math_ops.reduce_prod(dist.batch_shape_tensor()) batch_dims = array_ops.shape(dist.batch_shape_tensor())[0] edges_expanded_shape = 1 + array_ops.pad([-2], paddings=[[0, batch_dims]]) - for b, x in enumerate(array_ops.unstack(y, axis=1)): + for b, x in enumerate(array_ops.unstack(y, num=batch_size, axis=1)): counts, edges = self.histogram(x) edges = array_ops.reshape(edges, edges_expanded_shape) probs = math_ops.exp(dist.log_prob(edges)) @@ -323,7 +327,7 @@ class VectorDistributionTestHelpers(object): num_samples=int(1e5), seed=24, rtol=1e-2, - atol=0., + atol=0.1, cov_rtol=None, cov_atol=None): """Tests that sample/mean/covariance are consistent with each other. diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 92043d6a08833888c36009261addca0d14949ea8..7ce8a83fd91e2dfaa0ccef633f803b3ae595e646 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -22,30 +22,176 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator +from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.linalg.python.ops import linear_operator_addition as linop_add_lib -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib -static_value = distribution_util.static_value - __all__ = [ "VectorDiffeomixture", + "quadrature_scheme_softmaxnormal_gauss_hermite", + "quadrature_scheme_softmaxnormal_quantiles", ] +def quadrature_scheme_softmaxnormal_gauss_hermite( + loc, scale, quadrature_size, + validate_args=False, name=None): + """Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex. + + Note: for a given `quadrature_size`, this method is generally less accurate + than `quadrature_scheme_softmaxnormal_quantiles`. + + Args: + loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + Represents the `location` parameter of the SoftmaxNormal used for + selecting one of the `K` affine transformations. + scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + Represents the `scale` parameter of the SoftmaxNormal used for + selecting one of the `K` affine transformations. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + convex combination of affine parameters for `K` components. + `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex. + probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + associated with each grid point. + """ + with ops.name_scope(name, "quadrature_scheme_softmaxnormal_gauss_hermite", + [loc, scale]): + loc = ops.convert_to_tensor(loc, name="loc") + dt = loc.dtype.base_dtype + scale = ops.convert_to_tensor(scale, dtype=dt, name="scale") + + loc = maybe_check_quadrature_param(loc, "loc", validate_args) + scale = maybe_check_quadrature_param(scale, "scale", validate_args) + + grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size) + grid = grid.astype(loc.dtype.as_numpy_dtype) + probs = probs.astype(loc.dtype.as_numpy_dtype) + probs /= np.linalg.norm(probs, ord=1, keepdims=True) + probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype) + + grid = softmax( + -distribution_util.pad( + (loc[..., array_ops.newaxis] + + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid), + axis=-2, + front=True), + axis=-2) # shape: [B, components, deg] + + return grid, probs + + +def quadrature_scheme_softmaxnormal_quantiles( + loc, scale, quadrature_size, + validate_args=False, name=None): + """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex. + + Args: + loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + Represents the `location` parameter of the SoftmaxNormal used for + selecting one of the `K` affine transformations. + scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + Represents the `scale` parameter of the SoftmaxNormal used for + selecting one of the `K` affine transformations. + quadrature_size: Python scalar `int` representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + convex combination of affine parameters for `K` components. + `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex. + probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + associated with each grid point. + """ + with ops.name_scope(name, "softmax_normal_grid_and_probs", [loc, scale]): + loc = ops.convert_to_tensor(loc, name="loc") + dt = loc.dtype.base_dtype + scale = ops.convert_to_tensor(scale, dtype=dt, name="scale") + + loc = maybe_check_quadrature_param(loc, "loc", validate_args) + scale = maybe_check_quadrature_param(scale, "scale", validate_args) + + dist = normal_lib.Normal(loc=loc, scale=scale) + + def _get_batch_ndims(): + """Helper to get dist.batch_shape.ndims, statically if possible.""" + ndims = dist.batch_shape.ndims + if ndims is None: + ndims = array_ops.shape(dist.batch_shape_tensor())[0] + return ndims + batch_ndims = _get_batch_ndims() + + def _get_final_shape(qs): + """Helper to build `TensorShape`.""" + bs = dist.batch_shape.with_rank_at_least(1) + num_components = bs[-1].value + if num_components is not None: + num_components += 1 + tail = tensor_shape.TensorShape([num_components, qs]) + return bs[:-1].concatenate(tail) + + def _compute_quantiles(): + """Helper to build quantiles.""" + # Omit {0, 1} since they might lead to Inf/NaN. + zero = array_ops.zeros([], dtype=dist.dtype) + edges = math_ops.linspace(zero, 1., quadrature_size + 3)[1:-1] + # Expand edges so its broadcast across batch dims. + edges = array_ops.reshape(edges, shape=array_ops.concat([ + [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) + quantiles = dist.quantile(edges) + quantiles = SoftmaxCentered(event_ndims=1).forward(quantiles) + # Cyclically permute left by one. + perm = array_ops.concat([ + math_ops.range(1, 1 + batch_ndims), [0]], axis=0) + quantiles = array_ops.transpose(quantiles, perm) + quantiles.set_shape(_get_final_shape(quadrature_size + 1)) + return quantiles + quantiles = _compute_quantiles() + + # Compute grid as quantile midpoints. + grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2. + # Set shape hints. + grid.set_shape(_get_final_shape(quadrature_size)) + + # By construction probs is constant, i.e., `1 / quadrature_size`. This is + # important, because non-constant probs leads to non-reparameterizable + # samples. + probs = array_ops.fill( + dims=[quadrature_size], + value=1. / math_ops.cast(quadrature_size, dist.dtype)) + + return grid, probs + + class VectorDiffeomixture(distribution_lib.Distribution): """VectorDiffeomixture distribution. @@ -188,8 +334,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and # another with mix_loc=[1]. In both cases, `K=2` and the affine @@ -197,20 +342,20 @@ class VectorDiffeomixture(distribution_lib.Distribution): # k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity # k=1: loc=[2.]*dims scale=LinOpDiag dims = 5 - vdm = ds.VectorDiffeomixture( + vdm = tfd.VectorDiffeomixture( mix_loc=[[0.], [1]], mix_scale=[1.], - distribution=ds.Normal(loc=0., scale=1.), + distribution=tfd.Normal(loc=0., scale=1.), loc=[ None, # Equivalent to `np.zeros(dims, dtype=np.float32)`. np.float32([2.]*dims), ], scale=[ - la.LinearOperatorScaledIdentity( + tf.linalg.LinearOperatorScaledIdentity( num_rows=dims, multiplier=np.float32(1.1), is_positive_definite=True), - la.LinearOperatorDiag( + tf.linalg.LinearOperatorDiag( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], @@ -223,17 +368,20 @@ class VectorDiffeomixture(distribution_lib.Distribution): distribution, loc=None, scale=None, - quadrature_grid_and_probs=None, + quadrature_size=8, + quadrature_fn=quadrature_scheme_softmaxnormal_quantiles, validate_args=False, allow_nan_stats=True, name="VectorDiffeomixture"): - """Constructs the VectorDiffeomixture on `R**k`. + """Constructs the VectorDiffeomixture on `R**d`. Args: - mix_loc: `float`-like `Tensor`. Represents the `location` parameter of the - SoftmaxNormal used for selecting one of the `K` affine transformations. - mix_scale: `float`-like `Tensor`. Represents the `scale` parameter of the - SoftmaxNormal used for selecting one of the `K` affine transformations. + mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. Represents + the `location` parameter of the SoftmaxNormal used for selecting one of + the `K` affine transformations. + mix_scale: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. + Represents the `scale` parameter of the SoftmaxNormal used for selecting + one of the `K` affine transformations. distribution: `tf.Distribution`-like instance. Distribution from which `d` iid samples are used as input to the selected affine transformation. Must be a scalar-batch, scalar-event distribution. Typically @@ -252,10 +400,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): `k`-th element represents the `scale` used for the `k`-th affine transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices - quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s - representing the sample points and the corresponding (possibly - normalized) weight. When `None`, defaults to: - `np.polynomial.hermite.hermgauss(deg=8)`. + quadrature_size: Python `int` scalar representing number of + quadrature points. + quadrature_fn: Python callable taking `mix_loc`, `mix_scale`, + `quadrature_size`, `validate_args` and returning `tuple(grid, probs)` + representing the SoftmaxNormal grid and corresponding normalized weight. + normalized) weight. + Default value: `quadrature_scheme_softmaxnormal_quantiles`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -322,11 +473,8 @@ class VectorDiffeomixture(distribution_lib.Distribution): raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - grid, probs = distribution_util.process_quadrature_grid_and_probs( - quadrature_grid_and_probs, dtype, validate_args) - self._quadrature_grid = grid - self._quadrature_probs = probs - self._quadrature_size = distribution_util.dimension_size(probs, axis=0) + self._grid, probs = tuple(quadrature_fn( + mix_loc, mix_scale, quadrature_size, validate_args)) # Note: by creating the logits as `log(prob)` we ensure that # `self.mixture_distribution.logits` is equivalent to @@ -336,22 +484,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, allow_nan_stats=allow_nan_stats) - mix_loc = maybe_check_mix_param( - mix_loc, "mix_loc", dtype, validate_args) - mix_scale = maybe_check_mix_param( - mix_scale, "mix_scale", dtype, validate_args) - asserts = distribution_util.maybe_check_scalar_distribution( distribution, dtype, validate_args) if asserts: - mix_loc = control_flow_ops.with_dependencies(asserts, mix_loc) + self._grid = control_flow_ops.with_dependencies( + asserts, self._grid) self._distribution = distribution - # shape: [B, deg] - self._interpolate_weight = math_ops.sigmoid( - mix_loc - + np.sqrt(2.) * mix_scale * grid) - self._interpolated_affine = [ AffineLinearOperator(shift=loc_, scale=scale_, @@ -359,15 +498,16 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( - interpolate_loc(self._quadrature_size, - self._interpolate_weight, - loc), - interpolate_scale(self._quadrature_size, - self._interpolate_weight, - scale)))] + interpolate_loc(self._grid, loc), + interpolate_scale(self._grid, scale)))] - self._batch_shape_, self._event_shape_ = determine_batch_event_shapes( - mix_loc, mix_scale, self._endpoint_affine) + [ + self._batch_shape_, + self._batch_shape_tensor_, + self._event_shape_, + self._event_shape_tensor_, + ] = determine_batch_event_shapes(self._grid, + self._endpoint_affine) super(VectorDiffeomixture, self).__init__( dtype=dtype, @@ -386,8 +526,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=( - [mix_loc, mix_scale] - + distribution._graph_parents # pylint: disable=protected-access + distribution._graph_parents # pylint: disable=protected-access + [loc_ for loc_ in loc if loc_ is not None] + [p for scale_ in scale for p in scale_.graph_parents]), name=name) @@ -403,9 +542,9 @@ class VectorDiffeomixture(distribution_lib.Distribution): return self._distribution @property - def interpolate_weight(self): + def grid(self): """Grid of mixing probabilities, one for each grid point.""" - return self._interpolate_weight + return self._grid @property def endpoint_affine(self): @@ -417,27 +556,17 @@ class VectorDiffeomixture(distribution_lib.Distribution): """Affine transformation for each convex combination of `K` components.""" return self._interpolated_affine - @property - def quadrature_grid(self): - """Quadrature grid points.""" - return self._quadrature_grid - - @property - def quadrature_probs(self): - """Quadrature normalized weights.""" - return self._quadrature_probs - def _batch_shape_tensor(self): - return self._batch_shape_ + return self._batch_shape_tensor_ def _batch_shape(self): - return tensor_shape.TensorShape(static_value(self._batch_shape_)) + return self._batch_shape_ def _event_shape_tensor(self): - return self._event_shape_ + return self._event_shape_tensor_ def _event_shape(self): - return tensor_shape.TensorShape(static_value(self._event_shape_)) + return self._event_shape_ def _sample_n(self, n, seed=None): x = self.distribution.sample( @@ -450,25 +579,44 @@ class VectorDiffeomixture(distribution_lib.Distribution): # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get # ids as a [n]-shaped vector. - batch_size = reduce_prod(self.batch_shape_tensor()) - ids = self._mixture_distribution.sample( + batch_size = self.batch_shape.num_elements() + if batch_size is None: + batch_size = array_ops.reduce_prod(self.batch_shape_tensor()) + mix_batch_size = self.mixture_distribution.batch_shape.num_elements() + if mix_batch_size is None: + mix_batch_size = math_ops.reduce_prod( + self.mixture_distribution.batch_shape_tensor()) + ids = self.mixture_distribution.sample( sample_shape=concat_vectors( [n], distribution_util.pick_vector( self.is_scalar_batch(), np.int32([]), - [batch_size])), + [batch_size // mix_batch_size])), seed=distribution_util.gen_new_seed( seed, "vector_diffeomixture")) - - # Stride `quadrature_size` for `batch_size` number of times. + # We need to flatten batch dims in case mixture_distribution has its own + # batch dims. + ids = array_ops.reshape(ids, shape=concat_vectors( + [n], + distribution_util.pick_vector( + self.is_scalar_batch(), + np.int32([]), + np.int32([-1])))) + + # Stride `components * quadrature_size` for `batch_size` number of times. + stride = self.grid.shape.with_rank_at_least( + 2)[-2:].num_elements() + if stride is None: + stride = array_ops.reduce_prod( + array_ops.shape(self.grid)[-2:]) offset = math_ops.range(start=0, - limit=batch_size * self._quadrature_size, - delta=self._quadrature_size, + limit=batch_size * stride, + delta=stride, dtype=ids.dtype) weight = array_ops.gather( - array_ops.reshape(self.interpolate_weight, shape=[-1]), + array_ops.reshape(self.grid, shape=[-1]), ids + offset) weight = weight[..., array_ops.newaxis] @@ -500,10 +648,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): self.mixture_distribution.logits - fldj + log_prob, axis=-1) def _mean(self): - # Since we created logits to already be scaled, we can use exp which is - # slightly cheaper than `self.mixture_distribution.probs`. - p = math_ops.exp(self.mixture_distribution.logits) - + p = self._expand_mix_distribution_probs() m = self._expand_base_distribution_mean() mean = None for k, aff in enumerate(self.interpolated_affine): @@ -537,9 +682,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._covariance_of_mean_given_quadrature_component(diag_only=True)) def _mean_of_covariance_given_quadrature_component(self, diag_only): - # Since we created logits to already be scaled, we can use exp which is - # slightly cheaper than `self.mixture_distribution.probs`. - p = math_ops.exp(self.mixture_distribution.logits) + p = self.mixture_distribution.probs # To compute E[Cov(Z|V)], we'll add matrices within three categories: # scaled-identity, diagonal, and full. Then we'll combine these at the end. @@ -611,10 +754,9 @@ class VectorDiffeomixture(distribution_lib.Distribution): def _covariance_of_mean_given_quadrature_component(self, diag_only): square = math_ops.square if diag_only else vec_osquare - # Since we created logits to already be scaled, we can use exp which is - # slightly cheaper than `self.mixture_distribution.probs`. - p = math_ops.exp(self.mixture_distribution.logits) - + p = self._expand_mix_distribution_probs() + if not diag_only: + p = p[..., array_ops.newaxis, :] # Assuming event.ndims=1. m = self._expand_base_distribution_mean() cov_e_z_given_v = None @@ -638,17 +780,25 @@ class VectorDiffeomixture(distribution_lib.Distribution): m.set_shape(self.batch_shape.concatenate(self.event_shape)) return m - -def maybe_check_mix_param(param, name, expected_base_dtype, validate_args): - """Helper which checks validity of `mix_loc` and `mix_scale` init args.""" + def _expand_mix_distribution_probs(self): + p = self.mixture_distribution.probs # [B, deg] + deg = p.shape.with_rank_at_least(1)[-1].value + if deg is None: + deg = array_ops.shape(p)[-1] + event_ndims = self.event_shape.ndims + if event_ndims is None: + event_ndims = array_ops.shape(self.event_shape_tensor())[0] + expand_shape = array_ops.concat([ + self.mixture_distribution.batch_shape_tensor(), + array_ops.ones([event_ndims], dtype=dtypes.int32), + [deg], + ], axis=0) + return array_ops.reshape(p, shape=expand_shape) + + +def maybe_check_quadrature_param(param, name, validate_args): + """Helper which checks validity of `loc` and `scale` init args.""" with ops.name_scope(name="check_" + name, values=[param]): - param = ops.convert_to_tensor(param, dtype=expected_base_dtype, name=name) - - if param.dtype.base_dtype != expected_base_dtype: - raise TypeError( - "dtype mismatch; {}.base_dtype=\"{}\" is not \"{}\".".format( - name, param.dtype.base_dtype.name, expected_base_dtype.name)) - assertions = [] if param.shape.ndims is not None: if param.shape.ndims == 0: @@ -679,79 +829,84 @@ def maybe_check_mix_param(param, name, expected_base_dtype, validate_args): return param -def determine_batch_event_shapes(mix_loc, mix_scale, endpoint_affine): +def determine_batch_event_shapes(grid, endpoint_affine): """Helper to infer batch_shape and event_shape.""" with ops.name_scope(name="determine_batch_event_shapes"): - mix_batch_shape = distribution_util.prefer_static_broadcast_shape( - array_ops.shape(mix_loc, name="mix_loc_shape"), - array_ops.shape(mix_scale, name="mix_scale_shape")) - if isinstance(mix_batch_shape, tensor_shape.TensorShape): - mix_batch_shape = mix_batch_shape.with_rank_at_least(1)[:-1] - else: - s = static_value(mix_batch_shape) - if s is not None: - mix_batch_shape = ops.convert_to_tensor( - s[:-1], dtype=dtypes.int32, name="mix_batch_shape") - else: - mix_batch_shape = mix_batch_shape[:-1] - - # We broadcast with a 1D constant to automatically make the result a - # TensorShape if possible. - batch_shape = distribution_util.prefer_static_broadcast_shape( - mix_batch_shape, - constant_op.constant([], dtype=dtypes.int32, name="batch_shape")) - event_shape = constant_op.constant( - [], dtype=dtypes.int32, name="event_shape") + # grid # shape: [B, k, q] + # endpoint_affine # len=k, shape: [B, d, d] + batch_shape = grid.shape[:-2] + batch_shape_tensor = array_ops.shape(grid)[:-2] + event_shape = None + event_shape_tensor = None + + def _set_event_shape(shape, shape_tensor): + if event_shape is None: + return shape, shape_tensor + return (array_ops.broadcast_static_shape(event_shape, shape), + array_ops.broadcast_dynamic_shape( + event_shape_tensor, shape_tensor)) + for aff in endpoint_affine: - b, e = distribution_util.shapes_from_loc_and_scale(aff.shift, aff.scale) - if batch_shape is None: - batch_shape = distribution_util.prefer_static_broadcast_shape( - mix_batch_shape, b) - else: - batch_shape = distribution_util.prefer_static_broadcast_shape( - batch_shape, b) - event_shape = distribution_util.prefer_static_broadcast_shape( - event_shape, e) - if isinstance(batch_shape, tensor_shape.TensorShape): - batch_shape = ops.convert_to_tensor( - batch_shape.as_list(), dtype=dtypes.int32, name="batch_shape") - if isinstance(event_shape, tensor_shape.TensorShape): - event_shape = ops.convert_to_tensor( - event_shape.as_list(), dtype=dtypes.int32, name="event_shape") - return batch_shape, event_shape - - -def interpolate_loc(deg, interpolate_weight, loc): + if aff.shift is not None: + batch_shape = array_ops.broadcast_static_shape( + batch_shape, aff.shift.shape[:-1]) + batch_shape_tensor = array_ops.broadcast_dynamic_shape( + batch_shape_tensor, array_ops.shape(aff.shift)[:-1]) + event_shape, event_shape_tensor = _set_event_shape( + aff.shift.shape[-1:], array_ops.shape(aff.shift)[-1:]) + + if aff.scale is not None: + batch_shape = array_ops.broadcast_static_shape( + batch_shape, aff.scale.batch_shape) + batch_shape_tensor = array_ops.broadcast_dynamic_shape( + batch_shape_tensor, aff.scale.batch_shape_tensor()) + event_shape, event_shape_tensor = _set_event_shape( + tensor_shape.TensorShape([aff.scale.range_dimension]), + aff.scale.range_dimension_tensor()[array_ops.newaxis]) + + return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor + + +def interpolate_loc(grid, loc): """Helper which interpolates between two locs.""" if len(loc) != 2: raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(loc))) - with ops.name_scope("interpolate_loc", values=[interpolate_weight, loc]): + deg = grid.shape.with_rank_at_least(1)[-1].value + if deg is None: + raise ValueError("Num quadrature grid points must be known prior " + "to graph execution.") + with ops.name_scope("interpolate_loc", values=[grid, loc]): if loc is None or loc[0] is None and loc[1] is None: return [None]*deg - w = interpolate_weight[..., array_ops.newaxis, :] # shape: [B, 1, deg] + # shape: [B, 1, k, deg] + w = grid[..., array_ops.newaxis, :, :] loc = [x[..., array_ops.newaxis] # shape: [B, e, 1] if x is not None else None for x in loc] if loc[0] is None: - x = (1. - w) * loc[1] # shape: [B, e, deg] + x = w[..., 1, :] * loc[1] # shape: [B, e, deg] elif loc[1] is None: - x = w * loc[0] # shape: [B, e, deg] + x = w[..., 0, :] * loc[0] # shape: [B, e, deg] else: delta = loc[0] - loc[1] - x = w * delta + loc[1] # shape: [B, e, deg] + x = w[..., 0, :] * delta + loc[1] # shape: [B, e, deg] return [x[..., k] for k in range(deg)] # list(shape:[B, e]) -def interpolate_scale(deg, interpolate_weight, scale): +def interpolate_scale(grid, scale): """Helper which interpolates between two scales.""" if len(scale) != 2: raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - with ops.name_scope("interpolate_scale", values=[interpolate_weight]): + deg = grid.shape.with_rank_at_least(1)[-1].value + if deg is None: + raise ValueError("Num quadrature grid points must be known prior " + "to graph execution.") + with ops.name_scope("interpolate_scale", values=[grid]): return [linop_add_lib.add_operators([ - linop_scale(interpolate_weight[..., k], scale[0]), - linop_scale(1. - interpolate_weight[..., k], scale[1]), - ])[0] for k in range(deg)] + linop_scale(grid[..., k, q], s) + for k, s in enumerate(scale) + ])[0] for q in range(deg)] def linop_scale(w, op): @@ -791,39 +946,12 @@ def linop_scale(w, op): def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" - args_ = [static_value(x) for x in args] + args_ = [distribution_util.static_value(x) for x in args] if any(vec is None for vec in args_): return array_ops.concat(args, axis=0) return [val for vec in args_ for val in vec] -def reduce_prod(x): - """Same as `math_ops.reduce_prod` but statically if possible.""" - x_ = static_value(x) - if x_ is not None: - return np.prod(x_, dtype=x.dtype.as_numpy_dtype) - return array_ops.reduce_prod(x) - - -def ndims_from_shape(shape): - """Returns `Tensor`'s `rank` implied by a `Tensor` shape.""" - if shape.shape.ndims not in (None, 1): - raise ValueError("input is not a valid shape: not 1D") - if not shape.dtype.is_integer: - raise TypeError("input is not a valid shape: wrong dtype") - if shape.shape.is_fully_defined(): - return shape.shape.as_list()[0] - return array_ops.shape(shape)[0] - - -def ndims(x): - """Returns rank, statically if possible.""" - x = ops.convert_to_tensor(x) - if x.shape.ndims is not None: - return x.shape.ndims - return array_ops.rank(x) - - def add(x, y): """Adds inputs; interprets `None` as zero.""" if x is None: @@ -836,3 +964,18 @@ def add(x, y): def vec_osquare(x): """Computes the outer-product of a (batch of) vector, i.e., x.T x.""" return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :] + + +def softmax(x, axis, name=None): + """Equivalent to tf.nn.softmax but works around b/70297725.""" + with ops.name_scope(name, "softmax", [x, axis]): + x = ops.convert_to_tensor(x, name="x") + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x, name="ndims")) + axis = ops.convert_to_tensor(axis, dtype=dtypes.int32, name="axis") + axis_ = tensor_util.constant_value(axis) + if axis_ is not None: + axis = np.int(ndims + axis_ if axis_ < 0 else axis_) + else: + axis = array_ops.where(axis < 0, ndims + axis, axis) + return nn_ops.softmax(x, axis=axis) diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 356d78b67a8107750f68f7f84d73d1231f5b2b03..526fe2d39aef9aed833b889de80e849c469435e7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -89,14 +89,13 @@ class VectorExponentialDiag( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. # The first component has pdf exp{-x}, the second 0.5 exp{-x / 2} - vex = ds.VectorExponentialDiag(scale_diag=[1., 2.]) + vex = tfd.VectorExponentialDiag(scale_diag=[1., 2.]) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([3., 4.]).eval() # shape: [] @@ -107,7 +106,7 @@ class VectorExponentialDiag( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialDiag(loc, scale_diag) + vex = tfd.VectorExponentialDiag(loc, scale_diag) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index b313a851b381e5b3a057fd17e6c2ef4eb0fc34f1..9d5fd9ac4178a1ae29b1ce32f304b22fd3d234dc 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -107,16 +107,15 @@ class VectorExponentialLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. mat = [[1.0, 0.1], [0.1, 1.0]] - vex = ds.VectorExponentialLinearOperator( - scale=la.LinearOperatorFullMatrix(mat)) + vex = tfd.VectorExponentialLinearOperator( + scale=tf.linalg.LinearOperatorFullMatrix(mat)) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([1., 2.]).eval() # shape: [] @@ -127,9 +126,9 @@ class VectorExponentialLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialLinearOperator( + vex = tfd.VectorExponentialLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 0e3867809a820f49cfa7f5282c47f786626481a6..8dd983b750d9b39775e570800006011f4968f7f3 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -101,10 +101,10 @@ class VectorLaplaceDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -118,7 +118,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -136,7 +136,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate VectorLaplace's. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index c7abdbb4caf9bee4cbd5991eb5d652f20dd0f8d1..ec485c95c15da2794b67d2699d2bdd9db97bb6c4 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -109,8 +109,7 @@ class VectorLaplaceLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate VectorLaplace with some desired covariance. mu = [1., 2, 3] @@ -124,9 +123,9 @@ class VectorLaplaceLinearOperator( # [ 0.1, -0.3, 0.4]]) # Divide scale by sqrt(2) so that the final covariance will be what we want. - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2))) + scale=tf.linalg.LinearOperatorLowerTriangular(scale / tf.sqrt(2.))) # Covariance agrees with cholesky(cov) parameterization. vla.covariance().eval() @@ -143,9 +142,9 @@ class VectorLaplaceLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 544a8710709a0afb56c6ae6f36d35de892e8e420..e1ccf116457a97261b9ce3965552764771d3bdd2 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -143,7 +143,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): broadcastable with `event_shape`. distribution: `tf.Distribution`-like instance. Distribution from which `k` iid samples are used as input to transformation `F`. Default is - `ds.Normal(0., 1.)`. + `tf.distributions.Normal(loc=0., scale=1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 29d41ab81c62d621c3c3533e1449341e9a085645..8c67647a618d22a58428d78865c4ebf7d98bdf9e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -91,14 +91,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): Extra leading dimensions, if provided, allow for batches. ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate vector Student's t-distribution. mu = [1., 2, 3] chol = [[1., 0, 0.], [1, 3, 0], [1, 2, 3]] - vt = ds.VectorStudentT(df=2, loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(df=2, loc=mu, scale_tril=chol) # Evaluate this on an observation in R^3, returning a scalar. vt.prob([-1., 0, 1]) @@ -107,7 +107,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): mu = [[1., 2, 3], [11, 22, 33]] chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal. - vt = ds.VectorStudentT(loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(loc=mu, scale_tril=chol) # Evaluate this on a two observations, each in R^3, returning a length two # tensor. diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index dcc370cd00d5f93cd5b145a31fd58ef5041a86a8..09242ee47ddd044dfc99e22d5b7751a989c86485 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -76,3 +76,6 @@ For an introduction to eager execution in TensorFlow, see: ## Changelog - 2017/10/31: Initial preview release. +- 2017/12/01: Example of dynamic neural network: + [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). + See [README.md](python/examples/spinn/README.md) for details. diff --git a/tensorflow/contrib/eager/proto/BUILD b/tensorflow/contrib/eager/proto/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..aedfec8924e7314addd22349c0576a84a58d9aa3 --- /dev/null +++ b/tensorflow/contrib/eager/proto/BUILD @@ -0,0 +1,24 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_proto_library( + name = "checkpointable_object_graph_proto", + srcs = [ + "checkpointable_object_graph.proto", + ], + visibility = ["//tensorflow/contrib/eager/python:__subpackages__"], +) diff --git a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto new file mode 100644 index 0000000000000000000000000000000000000000..c962638aa11c06dcd5be6a794314e029ae84e572 --- /dev/null +++ b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +option cc_enable_arenas = true; + +package tensorflow.contrib.eager; + +// Prototype for an addition to BundleHeaderProto which saves extra information +// about the objects which own variables, allowing for more robust checkpoint +// loading into modified programs. + +message CheckpointableObjectGraph { + message Object { + message ObjectReference { + // An index into `CheckpointableObjectGraph.nodes`, indicating the object + // being referenced. + int32 node_id = 1; + // A numeric identifier for this object within its parent. + int32 local_uid = 2; + // A user-provided name for the edge. May be blank/omitted, in which case + // there is no explicitly provided local name; fall back on local_uid. + string local_name = 3; + } + + message VariableReference { + // A name for the variable which is unique within the object which owns + // it. Does not include a name_scope or variable_scope prefix. + string local_name = 1; + // The full name of the variable. Used to allow name-based loading of + // checkpoints which were saved using an object-based API. + string full_name = 2; + } + + message SlotVariableReference { + // An index into `CheckpointableObjectGraph.nodes`, indicating the object + // which created the variable that this variable is slotting for. + int32 original_variable_node_id = 1; + // The local name of the variable being slotted for within the object that + // owns it. + string original_variable_local_name = 2; + // The name of the slot (e.g. "m"/"v"). + string slot_name = 3; + // The full name of the slot variable. Used to allow name-based loading of + // checkpoints which were saved using an object-based API. + string full_name = 4; + } + + // Objects which this object depends on. + repeated ObjectReference children = 1; + // Non-slot variables owned by this object. + repeated VariableReference variables = 2; + // Slot variables owned by this object. + repeated SlotVariableReference slot_variables = 3; + } + + repeated Object nodes = 1; +} diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index bf2e883bc53c3281ef89d1200f5a089305ef3e72..e984c63af7ce2b32ab30121bf34bb2de4dfeb218 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -19,6 +19,8 @@ py_library( "//tensorflow/python:framework_test_lib", "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:template", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", @@ -67,6 +69,7 @@ cuda_py_test( srcs = ["datasets_test.py"], additional_deps = [ ":datasets", + "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", @@ -103,37 +106,6 @@ cuda_py_test( ], ) -py_library( - name = "summary_writer", - srcs = ["summary_writer.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/summary:gen_summary_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary_op_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - ], -) - -cuda_py_test( - name = "summary_writer_test", - srcs = ["summary_writer_test.py"], - additional_deps = [ - ":summary_writer", - "//third_party/py/numpy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:constant_op", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - py_library( name = "metrics", srcs = [ @@ -232,6 +204,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":network", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:constant_op", "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", @@ -246,6 +219,39 @@ py_test( ], ) +py_library( + name = "checkpointable", + srcs = ["checkpointable.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "checkpointable_test", + srcs = ["checkpointable_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpointable", + ":network", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:layers", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "@six_archive//:six", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py new file mode 100644 index 0000000000000000000000000000000000000000..b141ffb2bc03b8e38f8481bc044c3aae7e156c15 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable.py @@ -0,0 +1,392 @@ +"""An object-local variable management scheme.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re + +from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import saver as saver_lib + +_CheckpointableReference = collections.namedtuple( + "_CheckpointableReference", + [ + "name", # The local name if explicitly specified, else None. + "local_uid", # 0 for the first dependency, 1 for the next, ... Used for + # routing checkpointed variables to their correct + # Checkpointables when "name" is not set (see docstring of + # `track_checkpointable`). + "ref" # The Checkpointable object being referenced. + ]) + +_OwnedVariable = collections.namedtuple( + "_OwnedVariable", + [ + "name", # The variable's (local) name. + "variable" # The owned variable object. + ]) + +# Validation regular expression for the local names of Checkpointable +# objects. In particular, disallows "/" in names, and reserves +# underscore-prefixed names. +_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.-]*$") + +# Keyword for identifying that the next bit of a checkpoint variable name is a +# slot name. May not be the local name of a checkpointable. Checkpoint names for +# slot variables look like: +# +# /<_OPTIMIZER_SLOTS_NAME>// +# +# Where is a full path from the checkpoint root to the +# variable being slotted for. +_OPTIMIZER_SLOTS_NAME = "_OPTIMIZER_SLOT" + + +class Checkpointable(object): + """Manages variables and dependencies on other objects. + + To make reliable checkpoints, all `Checkpointable`s on which this object + depends must be registered in the constructor using `track_checkpointable` in + a deterministic order, and if possible they should be named. Variables may be + created using `add_variable` outside of the constructor and in any order, but + only these variables will be saved. + """ + + def __init__(self): + # Basically less useful OrderedDicts but without the reference cycles. + # TODO(allenl): Switch these to OrderedDict once TensorFlow supports only + # Python 3.6+. + self._checkpoint_dependencies = [] # A list of _CheckpointableReference + # objects. + self._dependency_names = set() + self._owned_variables = [] # A list of _OwnedVariable objects. + self._owned_variable_names = set() + + def add_variable(self, name, shape, dtype=None, initializer=None, **kwargs): + """Create a new variable object to be saved with this `Checkpointable`. + + If the user has requested that this object or another `Checkpointable` which + depends on this object be restored from a checkpoint (deferred loading + before variable object creation), `initializer` may be ignored and the value + from the checkpoint used instead. + + Args: + name: A name for the variable. Must be unique within this object. + shape: The shape of the variable. + dtype: The data type of the variable. + initializer: The initializer to use. Ignored if deferred loading has been + requested. + **kwargs: Passed to get_variable. + + Returns: + The new variable object. + + Raises: + ValueError: If the variable name is not unique. + """ + if name in self._owned_variable_names: + raise ValueError( + ("A variable named '%s' already exists in this Checkpointable, but " + "Checkpointable.add_variable called to create another with " + "that name. Variable names must be unique within a Checkpointable " + "object.") % (name,)) + if "getter" in kwargs: + # Allow the getter to be overridden, typically because there is a need for + # compatibility with some other variable creation mechanism. This should + # be relatively uncommon in user code. + getter = kwargs.pop("getter") + else: + getter = variable_scope.get_variable + # TODO(allenl): handle deferred loading + new_variable = getter( + name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs) + self._owned_variables.append( + _OwnedVariable(name=name, variable=new_variable)) + self._owned_variable_names.add(name) + return new_variable + + def track_checkpointable(self, checkpointable, name=None): + """Declare a dependency on another `Checkpointable` object. + + Indicates that checkpoints for this object should include variables from + `checkpointable`. + + Variables in a checkpoint are mapped to `Checkpointable`s based on names if + provided when the checkpoint was written, but otherwise use the order those + `Checkpointable`s were declared as dependencies. Both `name` arguments and + the dependency declaration order should be deterministic. + + There are two sufficient conditions to avoid breaking existing checkpoints + when modifying a class: (1) New dependencies must be declared after existing + dependencies, and (2) dependencies which were previously declared may never + be removed (a trivial placeholder with the same name may be used instead). + + Args: + checkpointable: A `Checkpointable` which this object depends on. + name: A local name for `checkpointable`, used for loading checkpoints into + the correct objects. If provided, it must be unique within this + `Checkpointable`. If None, dependency declaration order is used instead. + + Returns: + `checkpointable`, for convenience when declaring a dependency and + assigning to a member variable in one statement. + + Raises: + RuntimeError: If __init__ was not called. + TypeError: If `checkpointable` does not inherit from `Checkpointable`. + ValueError: For invalid names. + """ + if not hasattr(self, "_checkpoint_dependencies"): + raise RuntimeError("Need to call Checkpointable.__init__ before calling " + "Checkpointable.track_checkpointable().") + if not isinstance(checkpointable, Checkpointable): + raise TypeError( + ("Checkpointable.track_checkpointable() passed type %s, not a " + "Checkpointable.") % (type(checkpointable),)) + if name is not None: + if not _VALID_LOCAL_NAME.match(name): + raise ValueError( + ("Checkpointable names must match the regular expression '%s', but " + "got an invalid name '%s' instead.") % (_VALID_LOCAL_NAME.pattern, + name)) + if name in self._dependency_names: + raise ValueError( + ("Called Checkpointable.track_checkpointable() with name='%s', but " + "a Checkpointable with this name is already declared as a " + "dependency. If provided, names must be unique.") % (name,)) + self._dependency_names.add(name) + self._checkpoint_dependencies.append( + _CheckpointableReference( + name=name, + ref=checkpointable, + # TODO(allenl): Should this be exposed to allow users to stop + # depending on things and still load checkpoints when not using + # names? + local_uid=len(self._checkpoint_dependencies))) + return checkpointable + + @property + def checkpoint_dependencies(self): + """Other `Checkpointable` objects on which this object depends.""" + return self._checkpoint_dependencies + + +def _breadth_first_checkpointable_traversal(root_checkpointable): + """Find shortest paths to all variables owned by dependencies of root.""" + bfs_sorted = [] + root_checkpointable_reference = _CheckpointableReference( + name=None, local_uid=0, ref=root_checkpointable) + to_visit = collections.deque([root_checkpointable_reference]) + path_to_root = {root_checkpointable_reference: ()} + while to_visit: + current_checkpointable = to_visit.popleft() + bfs_sorted.append(current_checkpointable) + for child_checkpointable in ( + current_checkpointable.ref.checkpoint_dependencies): + if child_checkpointable not in path_to_root: + path_to_root[child_checkpointable] = ( + path_to_root[current_checkpointable] + (child_checkpointable,)) + to_visit.append(child_checkpointable) + return bfs_sorted, path_to_root + + +def _object_prefix_from_path(path_to_root): + return "/".join((checkpointable.name if checkpointable.name else "_%d" % ( + checkpointable.local_uid,)) for checkpointable in path_to_root) + + +def _escape_variable_name(variable_name): + # We need to support slashes in variable names for compatibility, since this + # naming scheme is being patched in to things like Layer.add_variable where + # slashes were previously accepted. We also want to use slashes to indicate + # edges traversed to reach the variable, so we escape forward slashes in + # variable names. + return variable_name.replace("_S_", "_S_.").replace(r"/", r"_S__") + + +def _variable_naming_for_object(path_to_root): + """Make a function for naming variables in an object.""" + # Name non-slot variables: + # + # / + # + # is not necessarily unique, but this is fine since we also + # save the graph of `Checkpointable`s with the checkpoint. Even if this path + # no longer exists because of a change in the Python program, we can look up + # the `Checkpointable` which owns the variable in the checkpoint's graph and + # use another path if one still exists. + + object_prefix = _object_prefix_from_path(path_to_root) + if object_prefix: + object_prefix += "/" + + def _name_single_variable(owned_variable): + """Names a variable within an object.""" + return object_prefix + _escape_variable_name(owned_variable.name) + + return _name_single_variable + + +def _slot_variable_naming_for_optimizer(optimizer, path_to_root): + """Make a function for naming slot variables in an optimizer.""" + # Name slot variables: + # + # /<_OPTIMIZER_SLOTS_NAME>// + # + # where is exactly the checkpoint name used for the original + # variable, including the path from the checkpoint root and the local name in + # the object which owns it. Note that we only save slot variables if the + # variable it's slotting for is also being saved. + + optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, + _object_prefix_from_path(path_to_root)) + + def _name_slot_variable(variable_path, slot_name): + """With an optimizer specified, name a slot variable.""" + + if not _VALID_LOCAL_NAME.match(slot_name): + # Slot variable names include the name of the slot. We need to + # validate that part of the name to be sure that the checkpoint name + # is a valid name scope name. + raise ValueError( + ("Could not save slot variables for optimizer %s, because its " + "slot name has invalid characters (got '%s', was expecting it " + "to match the regular expression '%s').") % + (optimizer, slot_name, _VALID_LOCAL_NAME.pattern)) + + return variable_path + optimizer_identifier + slot_name + + return _name_slot_variable + + +def _serialize_non_slot_variables(checkpointable_objects, path_to_root, + object_graph_proto): + """Name non-slot variables and add them to `object_graph_proto`.""" + named_variables = {} + non_slot_variables = [] + checkpoint_node_ids = {} + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + checkpoint_node_ids[checkpointable] = checkpoint_id + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + naming_scheme = _variable_naming_for_object(path_to_root[checkpointable]) + object_proto = object_graph_proto.nodes.add() + for owned_variable in checkpointable.ref._owned_variables: # pylint: disable=protected-access + variable_name = naming_scheme(owned_variable) + named_variables[variable_name] = owned_variable.variable + non_slot_variables.append(( + variable_name, # The variable's full checkpoint name + owned_variable, # The variable's _OwnedVariable object + checkpoint_id)) # The checkpoint ID of the node which owns this + # variable. + variable_proto = object_proto.variables.add() + variable_proto.local_name = owned_variable.name + # Figure out the name-based Saver's name for this variable. + saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [owned_variable.variable], convert_variable_to_tensor=False) + variable_full_name, = saver_dict.keys() + variable_proto.full_name = variable_full_name + + for child in checkpointable.ref.checkpoint_dependencies: + child_proto = object_proto.children.add() + child_proto.node_id = checkpoint_node_ids[child] + child_proto.local_uid = child.local_uid + if child.name is not None: + child_proto.local_name = child.name + return named_variables, non_slot_variables + + +def _serialize_slot_variables(checkpointable_objects, path_to_root, + non_slot_variables, object_graph_proto): + """Name slot variables and add them to `object_graph_proto`.""" + named_slot_variables = {} + for optimizer_checkpoint_id, checkpointable_ref in enumerate( + checkpointable_objects): + if isinstance(checkpointable_ref.ref, optimizer_lib.Optimizer): + optimizer_object_proto = object_graph_proto.nodes[optimizer_checkpoint_id] + naming_scheme = _slot_variable_naming_for_optimizer( + optimizer=checkpointable_ref.ref, + path_to_root=path_to_root[checkpointable_ref]) + slot_names = checkpointable_ref.ref.get_slot_names() + for (variable_path, owned_variable, + original_node_checkpoint_id) in non_slot_variables: + for slot_name in slot_names: + slot_variable = checkpointable_ref.ref.get_slot( + owned_variable.variable, slot_name) + if slot_variable is not None: + checkpoint_name = naming_scheme( + variable_path=variable_path, slot_name=slot_name) + named_slot_variables[checkpoint_name] = slot_variable + slot_variable_proto = optimizer_object_proto.slot_variables.add() + slot_variable_proto.slot_name = slot_name + # Figure out the name-based Saver's name for this variable. + saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [slot_variable], convert_variable_to_tensor=False) + slot_variable_full_name, = saver_dict.keys() + slot_variable_proto.full_name = slot_variable_full_name + slot_variable_proto.original_variable_local_name = ( + owned_variable.name) + slot_variable_proto.original_variable_node_id = ( + original_node_checkpoint_id) + return named_slot_variables + + +# TODO(allenl): Convenience utility for saving multiple objects (i.e. construct +# a root Checkpointable if passed a list of Checkpointables). +def _serialize_object_graph(root_checkpointable): + """Determine checkpoint keys for variables and build a serialized graph. + + Non-slot variables are keyed based on a shortest path from the root saveable + to the object which owns the variable (i.e. the one which called + `Checkpointable.add_variable` to create it). + + Slot variables are keyed based on a shortest path to the variable being + slotted for, a shortest path to their optimizer, and the slot name. + + Args: + root_checkpointable: A `Checkpointable` object whose variables (including + the variables of dependencies, recursively) should be saved. + + Returns: + A tuple of (named_variables, object_graph_proto): + named_variables: A dictionary mapping names to variable objects. + object_graph_proto: A CheckpointableObjectGraph protocol buffer containing + the serialized object graph and variable references. + + Raises: + ValueError: If there are invalid characters in an optimizer's slot names. + """ + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + + # Gather non-slot variables. + named_variables, non_slot_variables = _serialize_non_slot_variables( + checkpointable_objects, path_to_root, object_graph_proto) + + # Gather slot variables which are associated with variables gathered above. + named_slot_variables = _serialize_slot_variables( + checkpointable_objects, path_to_root, non_slot_variables, + object_graph_proto) + + named_variables.update(named_slot_variables) + return named_variables, object_graph_proto diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f820990bbe5fe6c9b4cdf890680aaad0847010c0 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_test.py @@ -0,0 +1,277 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import six + +from tensorflow.contrib.eager.python import checkpointable +from tensorflow.contrib.eager.python import network as network_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import adam +from tensorflow.python.training import training_util + + +class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + core.Dense.__init__(self, *args, **kwargs) + + def add_variable(self, name, shape, **kwargs): + # Calls both Checkpointable.add_variable and Layer.add_variable. Eventually + # Layer.add_variable should inherit from Checkpointable and simply call + # super and then do post-processing. + return checkpointable.Checkpointable.add_variable( + self, + name=name, + shape=shape, + getter=functools.partial(core.Dense.add_variable, self), + **kwargs) + + +# pylint: disable=not-callable +class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): + + def __init__(self): + network_lib.Network.__init__(self) + checkpointable.Checkpointable.__init__(self) + + def track_layer(self, layer, name=None): + self.track_checkpointable(layer, name=name) + return super(CheckpointableNetwork, self).track_layer(layer) + + +class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + adam.AdamOptimizer.__init__(self, *args, **kwargs) + + # NOTE: Copied from AdamOptimizer with modifications to use add_variable + # for non-slot variables. These contortions are necessary to maintain + # checkpoint compatibility with variable.name based saving. + def _create_slots(self, var_list): + # Create the beta1 and beta2 accumulators on the same device as the first + # variable. Sort the var_list to make sure this device is consistent across + # workers (these need to go on the same PS, otherwise some updates are + # silently ignored). + first_var = min(var_list, key=lambda x: x.name) + + create_new = self._beta1_power is None + if not create_new and context.in_graph_mode(): + create_new = (self._beta1_power.graph is not first_var.graph) + + if create_new: + with ops.colocate_with(first_var): + + def _variable_getter(name, shape, dtype, initializer): + del shape, dtype # not used, but there for compatibility + return variable_scope.variable( + name=name, initial_value=initializer, trainable=False) + + self._beta1_power = self.add_variable( + name="beta1_power", + shape=[], + initializer=self._beta1, + getter=_variable_getter) + self._beta2_power = self.add_variable( + name="beta2_power", + shape=[], + initializer=self._beta2, + getter=_variable_getter) + # Create slots for the first and second moments. + for v in var_list: + self._zeros_slot(v, "m", self._name) + self._zeros_slot(v, "v", self._name) + + # TODO(allenl): Override slot variable creation (_get_or_make_slot, + # _get_or_make_slot_with_initializer, _zeros_slot) to allow deferred + # loading. Likely no need to run this through add_variable, since gathering + # slot variables is special cased anyway. + + +class MyNetwork(CheckpointableNetwork): + """A concrete Network for testing.""" + + def __init__(self): + super(MyNetwork, self).__init__() + self._named = self.track_layer( + CheckpointableDenseLayer(1, use_bias=True), name="named_dense") + self._unnamed = self.track_layer( + CheckpointableDenseLayer(1, use_bias=False)) + + def call(self, values): + return self._unnamed(self._named(values)) + + +class Root(checkpointable.Checkpointable): + """A stand-in for a Trainer class.""" + + def __init__(self, optimizer, network): + super(Root, self).__init__() + self.track_checkpointable(optimizer, name="optimizer") + self.track_checkpointable(network, name="network") + self._global_step = None + + @property + def global_step(self): + if self._global_step is None: + # Get the default create_global_step utility to actually call + # self.add_variable, by setting a custom getter. + def _owned_variable_as_custom_getter(getter, *args, **kwargs): + return self.add_variable(*args, getter=getter, **kwargs) + + with variable_scope.variable_scope( + "", custom_getter=_owned_variable_as_custom_getter): + self._global_step = training_util.create_global_step() + return self._global_step + + +class CheckpointNamingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNamingWithOptimizer(self): + input_value = constant_op.constant([[3.]]) + network = MyNetwork() + # A nuisance Network using the same optimizer. Its slot variables should not + # go in the checkpoint, since it is never depended on. + other_network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root_checkpointable = Root(optimizer=optimizer, network=network) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value), + global_step=root_checkpointable.global_step) + optimizer.minimize( + lambda: other_network(input_value), + global_step=root_checkpointable.global_step) + else: + train_op = optimizer.minimize( + network(input_value), global_step=root_checkpointable.global_step) + optimizer.minimize( + other_network(input_value), + global_step=root_checkpointable.global_step) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(train_op) + named_variables, serialized_graph = checkpointable._serialize_object_graph( + root_checkpointable) + expected_checkpoint_names = ( + # Created in the root node, so no prefix. + "global_step", + # No name provided to track_checkpointable(), so the position (1, after + # the named track_checkpointable() which is 0) is used instead. + "network/_1/kernel", + # track_checkpointable() with a name provided, so that's used + "network/named_dense/kernel", + "network/named_dense/bias", + # The optimizer creates two non-slot variables + "optimizer/beta1_power", + "optimizer/beta2_power", + # Slot variables + "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/m", + "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/v", + "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/m", + "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/v", + "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m", + "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/v", + ) + six.assertCountEqual(self, expected_checkpoint_names, + named_variables.keys()) + # Check that we've mapped to the right variable objects (not exhaustive) + self.assertEqual("global_step:0", named_variables["global_step"].name) + self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0", + named_variables["network/_1/kernel"].name) + self.assertEqual("my_network/checkpointable_dense_layer/kernel:0", + named_variables["network/named_dense/kernel"].name) + self.assertEqual("beta1_power:0", + named_variables["optimizer/beta1_power"].name) + self.assertEqual("beta2_power:0", + named_variables["optimizer/beta2_power"].name) + # Spot check the generated protocol buffers. + self.assertEqual(0, serialized_graph.nodes[0].children[0].local_uid) + self.assertEqual("optimizer", + serialized_graph.nodes[0].children[0].local_name) + optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ + 0].node_id] + self.assertEqual("beta1_power", optimizer_node.variables[0].local_name) + self.assertEqual("beta1_power", optimizer_node.variables[0].full_name) + self.assertEqual( + "kernel", optimizer_node.slot_variables[0].original_variable_local_name) + original_variable_owner = serialized_graph.nodes[ + optimizer_node.slot_variables[0].original_variable_node_id] + self.assertEqual("kernel", original_variable_owner.variables[0].local_name) + self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) + # We strip off the :0 suffix, as variable.name-based saving does. + self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam", + optimizer_node.slot_variables[0].full_name) + self.assertEqual("my_network/checkpointable_dense_layer/kernel/Adam:0", + optimizer.get_slot( + var=named_variables["network/named_dense/kernel"], + name="m").name) + + def _get_checkpoint_name(self, name): + root = checkpointable.Checkpointable() + with variable_scope.variable_scope("get_checkpoint_name"): + # Create the variable in a variable scope so that we get more relaxed + # naming rules (variables outside a scope may not start with "_", "/" or + # "-"). Since we don't use the scope part of the name, these cases are + # somewhat annoying. + root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64) + named_variables, _ = checkpointable._serialize_object_graph(root) + checkpoint_name, = named_variables.keys() + with ops.name_scope("root/" + checkpoint_name): + pass # Make sure we can use this as an op name if we prefix it. + return checkpoint_name + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableNameEscaping(self): + self.assertEqual(r"a_S__b_S__c", self._get_checkpoint_name(r"a/b/c")) + self.assertEqual(r"", self._get_checkpoint_name(r"")) + self.assertEqual(r"_S__", self._get_checkpoint_name(r"/")) + self.assertEqual(r"_S___S_._", self._get_checkpoint_name(r"/_S__")) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNumberedPath(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + root.track_checkpointable(leaf) + leaf.add_variable(name="v", shape=[]) + named_variables, _ = checkpointable._serialize_object_graph(root) + variable_name, = named_variables.keys() + self.assertEqual(r"_0/v", variable_name) + + @test_util.run_in_graph_and_eager_modes() + def testLocalNameValidation(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + with self.assertRaisesRegexp(ValueError, "invalid name"): + # Leading underscores are reserved, which avoids conflicts with + # un-named edges in paths and the optimizer slots identifier. + root.track_checkpointable(leaf, name="_12") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index b559cce6b12a809d671ce7855680063f02a4ac22..a7f50c13bb992fd47669fb9956dde6b271e16ffd 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -23,6 +23,7 @@ import threading from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -41,7 +42,7 @@ def _generate_shared_name(prefix): global _uid_counter uid = _uid_counter _uid_counter += 1 - return "{}_{}".format(prefix, uid) + return "{}{}".format(prefix, uid) class Iterator(object): @@ -75,13 +76,16 @@ class Iterator(object): format(type(self))) with ops.device("/device:CPU:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access + self._output_classes = dataset.output_classes self._output_types = dataset.output_types self._output_shapes = dataset.output_shapes - self._flat_output_types = nest.flatten(dataset.output_types) - self._flat_output_shapes = nest.flatten(dataset.output_shapes) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._output_types, self._output_classes)) + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._output_shapes, self._output_classes)) self._resource = gen_dataset_ops.iterator( - container="", - shared_name=_generate_shared_name("eager_iterator"), + shared_name="", + container=_generate_shared_name("eageriterator"), output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) gen_dataset_ops.make_iterator(ds_variant, self._resource) @@ -125,22 +129,78 @@ class Iterator(object): def __next__(self): # For Python 3 compatibility return self.next() - def next(self): - """Return the next tf.Tensor from the dataset.""" + def _next_internal(self): + """Returns a nested structure of `tf.Tensor`s containing the next element. + """ with ops.device(self._device): - try: - if self._buffer_resource_handle is not None: - ret = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=self._buffer_resource_handle, - output_types=self._flat_output_types) - else: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - except errors.OutOfRangeError: - raise StopIteration - return nest.pack_sequence_as(self._output_types, ret) + if self._buffer_resource_handle is not None: + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + else: + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + ret = gen_dataset_ops.iterator_get_next( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) + + def next(self): + """Returns a nested structure of `tf.Tensor`s containing the next element. + """ + try: + return self._next_internal() + except errors.OutOfRangeError: + raise StopIteration + + @property + def output_classes(self): + """Returns the class of each component of an element of this iterator. + + The expected values are `tf.Tensor` and `tf.SparseTensor`. + + Returns: + A nested structure of Python `type` objects corresponding to each + component of an element of this dataset. + """ + return self._output_classes + + @property + def output_shapes(self): + """Returns the shape of each component of an element of this iterator. + + Returns: + A nested structure of `tf.TensorShape` objects corresponding to each + component of an element of this dataset. + """ + return self._output_shapes + + @property + def output_types(self): + """Returns the type of each component of an element of this iterator. + + Returns: + A nested structure of `tf.DType` objects corresponding to each component + of an element of this dataset. + """ + return self._output_types + + def get_next(self, name=None): + """Returns a nested structure of `tf.Tensor`s containing the next element. + + Args: + name: (Optional.) A name for the created operation. Currently unused. + + Returns: + A nested structure of `tf.Tensor` objects. + + Raises: + `tf.errors.OutOfRangeError`: If the end of the dataset has been reached. + """ + del name + return self._next_internal() diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index c924d81c9d85e638e4f35f260664c0ee7d03257e..a1611e92b113839c2dd2a3b2560b0ba90c0a7ef0 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -16,11 +16,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + +import numpy as np + +from tensorflow.contrib import lookup from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops @@ -33,6 +41,15 @@ class IteratorTest(test.TestCase): got.append(t.numpy()) self.assertAllEqual([0, 1, 2, 3], got) + def testGetNext(self): + iterator = datasets.Iterator(Dataset.range(4)) + self.assertEqual(0, iterator.get_next().numpy()) + self.assertEqual(1, iterator.get_next().numpy()) + self.assertEqual(2, iterator.get_next().numpy()) + self.assertEqual(3, iterator.get_next().numpy()) + with self.assertRaises(errors.OutOfRangeError): + iterator.get_next() + def testMultipleIteratorsOnTheSameDataset(self): ds = Dataset.range(4) it1 = datasets.Iterator(ds) @@ -64,6 +81,18 @@ class IteratorTest(test.TestCase): got = [x.numpy() for x in it] self.assertAllEqual([0, 4, 16, 36], got) + def testMapCaptureLookupTable(self): + default_val = -1 + keys = constant_op.constant(['brain', 'salad', 'surgery']) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) + dataset = Dataset.from_tensor_slices(['brain', 'salad', 'surgery']) + dataset = dataset.map(table.lookup) + it = datasets.Iterator(dataset) + got = [x.numpy() for x in it] + self.assertAllEqual([0, 1, 2], got) + def testMultipleIteratorsOnADatasetThatUsesFunctions(self): ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square) @@ -72,6 +101,53 @@ class IteratorTest(test.TestCase): got2 = [x.numpy() for x in datasets.Iterator(ds)] self.assertAllEqual(got1, got2) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparseTensorElements(self): + components = (sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 0], [2, 0]]), + values=np.array([0, 0, 0]), + dense_shape=np.array([3, 1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1], [2, 2]]), + values=np.array([1, 2, 3]), + dense_shape=np.array([3, 3]))) + + expected = [ + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([1]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[1]]), + values=np.array([2]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[2]]), + values=np.array([3]), + dense_shape=np.array([3]))), + ] + + for i, result in enumerate( + datasets.Iterator(Dataset.from_tensor_slices(components))): + self.assertSparseValuesEqual(expected[i][0], result[0]) + self.assertSparseValuesEqual(expected[i][1], result[1]) + def testPyFunc(self): def my_map(inp): @@ -90,5 +166,64 @@ class IteratorTest(test.TestCase): self.assertAllEqual([0., 2.], x.numpy()) +class DatasetConstructorBenchmark(test.Benchmark): + + def benchmarkSliceRepeatBatchEager(self): + input_size = 10000 + batch_size = 100 + num_epochs = 100 + + input_data = np.random.randn(input_size) + + dataset = ( + Dataset.from_tensor_slices(input_data).repeat(num_epochs) + .batch(batch_size)) + iterator = datasets.Iterator(dataset) + + ends = [time.time()] + for _ in iterator: + ends.append(time.time()) + + deltas = np.ediff1d(ends) + median_wall_time = np.median(deltas) + print( + 'Slice/repeat/batch eager input size: %d batch size: %d Median wall ' + 'time per element: %f' + % (input_size, batch_size, median_wall_time)) + self.report_benchmark( + iters=len(deltas), + wall_time=median_wall_time, + name='benchmark_slice_repeat_batch_eager_input_%d_batch_%d' % + (input_size, batch_size)) + + def benchmarkSliceBatchCacheRepeatCallable(self): + input_size = 10000 + batch_size = 100 + num_epochs = 100 + + input_data = np.random.randn(input_size) + + dataset = ( + Dataset.from_tensor_slices(input_data).batch(batch_size).cache() + .repeat(num_epochs)) + iterator = datasets.Iterator(dataset) + + ends = [time.time()] + for _ in iterator: + ends.append(time.time()) + + deltas = np.ediff1d(ends) + median_wall_time = np.median(deltas) + print( + 'Slice/batch/cache/repeat eager input size: %d batch size: %d Median ' + 'wall time per element: %f' + % (input_size, batch_size, median_wall_time)) + self.report_benchmark( + iters=len(deltas), + wall_time=median_wall_time, + name='benchmark_slice_batch_cache_repeat_eager_input_%d_batch_%d' % + (input_size, batch_size)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index bd0ab02ecf7ae6025e08dde1c3ddc634db9255c1..3faaeef5903615ea122800a6690117dde682e830 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -110,7 +110,7 @@ class Evaluator(object): return self._all_metric_results() else: def f(): - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( summary_logdir).as_default(), summary_ops.always_record_summaries(): return self._all_metric_results() if context.in_eager_mode(): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index aa21a6ab994acf929890ecebc07a86cf7ebf97db..15a21885f66eface291a39fa0ee1ff28bc297548 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -6,10 +6,12 @@ package(default_visibility = ["//tensorflow:internal"]) py_library( name = "examples_pip", deps = [ + "//tensorflow/contrib/eager/python/examples/gan:mnist", "//tensorflow/contrib/eager/python/examples/linear_regression", "//tensorflow/contrib/eager/python/examples/mnist", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c61ec2dbae60a782c0e6589701554b045dcb92ae --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/BUILD @@ -0,0 +1,36 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_binary( + name = "mnist", + srcs = ["mnist.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/examples/tutorials/mnist:input_data", + ], +) + +cuda_py_test( + name = "mnist_test", + srcs = ["mnist_test.py"], + additional_deps = [ + ":mnist", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "mnist_graph_test", + srcs = ["mnist_graph_test.py"], + additional_deps = [ + ":mnist", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/gan/README.md b/tensorflow/contrib/eager/python/examples/gan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e8c9db1a1e2eb5881b08a4d3866c82b24d64be12 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/README.md @@ -0,0 +1,38 @@ +# GAN with TensorFlow eager execution + +A simple Generative Adversarial Network (GAN) example using eager execution. +The discriminator and generator networks each contain a few convolution and +fully connected layers. + +Other eager execution examples can be found under the parent directory. + +## Content + +- `mnist.py`: Model definitions and training routines. +- `mnist_test.py`: Benchmarks for training and using the models using eager +execution. +- `mnist_graph_test.py`: Benchmarks for trainig and using the models using +graph execution. The same model definitions and loss functions are used in +all benchmarks. + + +## To run + +- Make sure you have installed TensorFlow 1.5+ or the latest `tf-nightly` +or `tf-nightly-gpu` pip package in order to access the eager execution feature. + +- Train model. E.g., + + ```bash + python mnist.py + ``` + + Use `--output_dir=` to direct the script to save TensorBoard summaries + during training. Disabled by default. + + Use `--checkpoint_dir=` to direct the script to save checkpoints to + `` during training. DIR defaults to /tmp/tensorflow/mnist/checkpoints/. + The script will load the latest saved checkpoint from this directory if + one exists. + + Use `-h` for other options. diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ac79f46c83bb709918e3b72830b90ddcfd71b4 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -0,0 +1,368 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A deep MNIST classifier using convolutional layers. + +Sample usage: + python mnist.py --help +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys +import time + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.examples.tutorials.mnist import input_data + +FLAGS = None + + +class Discriminator(tfe.Network): + """GAN Discriminator. + + A network to differentiate between generated and real handwritten digits. + """ + + def __init__(self, data_format): + """Creates a model for discriminating between real and generated digits. + + Args: + data_format: Either 'channels_first' or 'channels_last'. + 'channels_first' is typically faster on GPUs while 'channels_last' is + typically faster on CPUs. See + https://www.tensorflow.org/performance/performance_guide#data_formats + """ + super(Discriminator, self).__init__(name='') + if data_format == 'channels_first': + self._input_shape = [-1, 1, 28, 28] + else: + assert data_format == 'channels_last' + self._input_shape = [-1, 28, 28, 1] + self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME', + data_format=data_format, + activation=tf.tanh)) + self.pool1 = self.track_layer( + tf.layers.AveragePooling2D(2, 2, data_format=data_format)) + self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5, + data_format=data_format, + activation=tf.tanh)) + self.pool2 = self.track_layer( + tf.layers.AveragePooling2D(2, 2, data_format=data_format)) + self.flatten = self.track_layer(tf.layers.Flatten()) + self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh)) + self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None)) + + def call(self, inputs): + """Return two logits per image estimating input authenticity. + + Users should invoke __call__ to run the network, which delegates to this + method (and not call this method directly). + + Args: + inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1] + or [batch_size, 1, 28, 28] + + Returns: + A Tensor with shape [batch_size] containing logits estimating + the probability that corresponding digit is real. + """ + x = tf.reshape(inputs, self._input_shape) + x = self.conv1(x) + x = self.pool1(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + return x + + +class Generator(tfe.Network): + """Generator of handwritten digits similar to the ones in the MNIST dataset. + """ + + def __init__(self, data_format): + """Creates a model for discriminating between real and generated digits. + + Args: + data_format: Either 'channels_first' or 'channels_last'. + 'channels_first' is typically faster on GPUs while 'channels_last' is + typically faster on CPUs. See + https://www.tensorflow.org/performance/performance_guide#data_formats + """ + super(Generator, self).__init__(name='') + self.data_format = data_format + # We are using 128 6x6 channels as input to the first deconvolution layer + if data_format == 'channels_first': + self._pre_conv_shape = [-1, 128, 6, 6] + else: + assert data_format == 'channels_last' + self._pre_conv_shape = [-1, 6, 6, 128] + self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128, + activation=tf.tanh)) + + # In call(), we reshape the output of fc1 to _pre_conv_shape + + # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64) + self.conv1 = self.track_layer(tf.layers.Conv2DTranspose( + 64, 4, strides=2, activation=None, data_format=data_format)) + + # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) + self.conv2 = self.track_layer(tf.layers.Conv2DTranspose( + 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)) + + def call(self, inputs): + """Return a batch of generated images. + + Users should invoke __call__ to run the network, which delegates to this + method (and not call this method directly). + + Args: + inputs: A batch of noise vectors as a Tensor with shape + [batch_size, length of noise vectors]. + + Returns: + A Tensor containing generated images. If data_format is 'channels_last', + the shape of returned images is [batch_size, 28, 28, 1], else + [batch_size, 1, 28, 28] + """ + + x = self.fc1(inputs) + x = tf.reshape(x, shape=self._pre_conv_shape) + x = self.conv1(x) + x = self.conv2(x) + return x + + +def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs): + """Original discriminator loss for GANs, with label smoothing. + + See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more + details. + + Args: + discriminator_real_outputs: Discriminator output on real data. + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + + Returns: + A scalar loss Tensor. + """ + + loss_on_real = tf.losses.sigmoid_cross_entropy( + tf.ones_like(discriminator_real_outputs), discriminator_real_outputs, + label_smoothing=0.25) + loss_on_generated = tf.losses.sigmoid_cross_entropy( + tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs) + loss = loss_on_real + loss_on_generated + tf.contrib.summary.scalar('discriminator_loss', loss) + return loss + + +def generator_loss(discriminator_gen_outputs): + """Original generator loss for GANs. + + L = -log(sigmoid(D(G(z)))) + + See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) + for more details. + + Args: + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + + Returns: + A scalar loss Tensor. + """ + loss = tf.losses.sigmoid_cross_entropy( + tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs) + tf.contrib.summary.scalar('generator_loss', loss) + return loss + + +def train_one_epoch(generator, discriminator, + generator_optimizer, discriminator_optimizer, + dataset, log_interval, noise_dim): + """Trains `generator` and `discriminator` models on `dataset`. + + Args: + generator: Generator model. + discriminator: Discriminator model. + generator_optimizer: Optimizer to use for generator. + discriminator_optimizer: Optimizer to use for discriminator. + dataset: Dataset of images to train on. + log_interval: How many global steps to wait between logging and collecting + summaries. + noise_dim: Dimension of noise vector to use. + """ + + total_generator_loss = 0.0 + total_discriminator_loss = 0.0 + for (batch_index, images) in enumerate(tfe.Iterator(dataset)): + with tf.device('/cpu:0'): + tf.assign_add(tf.train.get_global_step(), 1) + + with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval): + current_batch_size = images.shape[0] + noise = tf.random_uniform(shape=[current_batch_size, noise_dim], + minval=-1., maxval=1., seed=batch_index) + + with tfe.GradientTape(persistent=True) as g: + generated_images = generator(noise) + tf.contrib.summary.image('generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) + + discriminator_gen_outputs = discriminator(generated_images) + discriminator_real_outputs = discriminator(images) + discriminator_loss_val = discriminator_loss(discriminator_real_outputs, + discriminator_gen_outputs) + total_discriminator_loss += discriminator_loss_val + + generator_loss_val = generator_loss(discriminator_gen_outputs) + total_generator_loss += generator_loss_val + + generator_grad = g.gradient(generator_loss_val, generator.variables) + discriminator_grad = g.gradient(discriminator_loss_val, + discriminator.variables) + + with tf.variable_scope('generator'): + generator_optimizer.apply_gradients(zip(generator_grad, + generator.variables)) + with tf.variable_scope('discriminator'): + discriminator_optimizer.apply_gradients(zip(discriminator_grad, + discriminator.variables)) + + if log_interval and batch_index > 0 and batch_index % log_interval == 0: + print('Batch #%d\tAverage Generator Loss: %.6f\t' + 'Average Discriminator Loss: %.6f' % ( + batch_index, total_generator_loss/batch_index, + total_discriminator_loss/batch_index)) + + +def main(_): + (device, data_format) = ('/gpu:0', 'channels_first') + if FLAGS.no_gpu or tfe.num_gpus() <= 0: + (device, data_format) = ('/cpu:0', 'channels_last') + print('Using device %s, and data format %s.' % (device, data_format)) + + # Load the datasets + data = input_data.read_data_sets(FLAGS.data_dir) + dataset = (tf.data.Dataset + .from_tensor_slices(data.train.images) + .shuffle(60000) + .batch(FLAGS.batch_size)) + + # Create the models and optimizers + generator = Generator(data_format) + discriminator = Discriminator(data_format) + with tf.variable_scope('generator'): + generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) + with tf.variable_scope('discriminator'): + discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) + + # Prepare summary writer and checkpoint info + summary_writer = tf.contrib.summary.create_summary_file_writer( + FLAGS.output_dir, flush_millis=1000) + checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') + latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) + if latest_cpkt: + print('Using latest checkpoint at ' + latest_cpkt) + + with tf.device(device): + for epoch in range(1, 101): + with tfe.restore_variables_on_create(latest_cpkt): + global_step = tf.train.get_or_create_global_step() + start = time.time() + with summary_writer.as_default(): + train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, + dataset, FLAGS.log_interval, FLAGS.noise) + end = time.time() + print('\nTrain time for epoch #%d (global step %d): %f' % ( + epoch, global_step.numpy(), end - start)) + + all_variables = ( + generator.variables + + discriminator.variables + + generator_optimizer.variables() + + discriminator_optimizer.variables() + + [global_step]) + tfe.Saver(all_variables).save( + checkpoint_prefix, global_step=global_step) + + +if __name__ == '__main__': + tfe.enable_eager_execution() + + parser = argparse.ArgumentParser() + parser.add_argument( + '--data-dir', + type=str, + default='/tmp/tensorflow/mnist/input_data', + help=('Directory for storing input data (default ' + '/tmp/tensorflow/mnist/input_data)')) + parser.add_argument( + '--batch-size', + type=int, + default=128, + metavar='N', + help='input batch size for training (default: 128)') + parser.add_argument( + '--log-interval', + type=int, + default=100, + metavar='N', + help=('number of batches between logging and writing summaries ' + '(default: 100)')) + parser.add_argument( + '--output_dir', + type=str, + default=None, + metavar='DIR', + help='Directory to write TensorBoard summaries (defaults to none)') + parser.add_argument( + '--checkpoint_dir', + type=str, + default='/tmp/tensorflow/mnist/checkpoints/', + metavar='DIR', + help=('Directory to save checkpoints in (once per epoch) (default ' + '/tmp/tensorflow/mnist/checkpoints/)')) + parser.add_argument( + '--lr', + type=float, + default=0.001, + metavar='LR', + help='learning rate (default: 0.001)') + parser.add_argument( + '--noise', + type=int, + default=100, + metavar='N', + help='Length of noise vector for generator input (default: 100)') + parser.add_argument( + '--no-gpu', + action='store_true', + default=False, + help='disables GPU usage even if a GPU is available') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..12b39b0cde49d4c017acfa74572c725036c54eff --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py @@ -0,0 +1,151 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import time + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.gan import mnist + +NOISE_DIM = 100 +# Big enough so that summaries are never recorded. +# Lower this value if would like to benchmark with some summaries. +SUMMARY_INTERVAL = 10000 +SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms + + +def data_format(): + return 'channels_first' if tf.test.is_gpu_available() else 'channels_last' + + +class MnistGraphGanBenchmark(tf.test.Benchmark): + + def _create_graph(self, batch_size): + # Generate some random data. + images_data = np.random.randn(batch_size, 784).astype(np.float32) + dataset = tf.data.Dataset.from_tensors(images_data) + images = dataset.repeat().make_one_shot_iterator().get_next() + + # Create the models and optimizers + generator = mnist.Generator(data_format()) + discriminator = mnist.Discriminator(data_format()) + with tf.variable_scope('generator'): + generator_optimizer = tf.train.AdamOptimizer(0.001) + with tf.variable_scope('discriminator'): + discriminator_optimizer = tf.train.AdamOptimizer(0.001) + + # Run models and compute loss + noise_placeholder = tf.placeholder(tf.float32, + shape=[batch_size, NOISE_DIM]) + generated_images = generator(noise_placeholder) + tf.contrib.summary.image('generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) + discriminator_gen_outputs = discriminator(generated_images) + discriminator_real_outputs = discriminator(images) + generator_loss = mnist.generator_loss(discriminator_gen_outputs) + discriminator_loss = mnist.discriminator_loss(discriminator_real_outputs, + discriminator_gen_outputs) + # Get train ops + with tf.variable_scope('generator'): + generator_train = generator_optimizer.minimize( + generator_loss, var_list=generator.variables) + with tf.variable_scope('discriminator'): + discriminator_train = discriminator_optimizer.minimize( + discriminator_loss, var_list=discriminator.variables) + + return (generator_train, discriminator_train, noise_placeholder) + + def _report(self, test_name, start, num_iters, batch_size): + avg_time = (time.time() - start) / num_iters + dev = 'gpu' if tf.test.is_gpu_available() else 'cpu' + name = 'graph_%s_%s_batch_%d_%s' % (test_name, dev, batch_size, + data_format()) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def benchmark_train(self): + for batch_size in [64, 128, 256]: + with tf.Graph().as_default(): + global_step = tf.train.get_or_create_global_step() + increment_global_step = tf.assign_add(global_step, 1) + with tf.contrib.summary.create_file_writer( + tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default(), ( + tf.contrib.summary.record_summaries_every_n_global_steps( + SUMMARY_INTERVAL)): + (generator_train, discriminator_train, noise_placeholder + ) = self._create_graph(batch_size) + + with tf.Session() as sess: + tf.contrib.summary.initialize(graph=tf.get_default_graph(), + session=sess) + + sess.run(tf.global_variables_initializer()) + + num_burn, num_iters = (3, 100) + for _ in range(num_burn): + noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM]) + # Increment global step before evaluating summary ops to avoid + # race condition. + sess.run(increment_global_step) + sess.run([generator_train, discriminator_train, + tf.contrib.summary.all_summary_ops()], + feed_dict={noise_placeholder: noise}) + + # Run and benchmark 2 epochs + start = time.time() + for _ in range(num_iters): + noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM]) + sess.run(increment_global_step) + sess.run([generator_train, discriminator_train, + tf.contrib.summary.all_summary_ops()], + feed_dict={noise_placeholder: noise}) + self._report('train', start, num_iters, batch_size) + + def benchmark_generate(self): + for batch_size in [64, 128, 256]: + with tf.Graph().as_default(): + # Using random weights. This will generate garbage. + generator = mnist.Generator(data_format()) + noise_placeholder = tf.placeholder(tf.float32, + shape=[batch_size, NOISE_DIM]) + generated_images = generator(noise_placeholder) + + init = tf.global_variables_initializer() + with tf.Session() as sess: + sess.run(init) + noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM]) + num_burn, num_iters = (30, 1000) + for _ in range(num_burn): + sess.run(generated_images, feed_dict={noise_placeholder: noise}) + + start = time.time() + for _ in range(num_iters): + # Comparison with the eager execution benchmark in mnist_test.py + # isn't entirely fair as the time here includes the cost of copying + # the feeds from CPU memory to GPU. + sess.run(generated_images, feed_dict={noise_placeholder: noise}) + self._report('generate', start, num_iters, batch_size) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4a3ca8d82bc2619b05a734f6d2e58431c1a45995 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -0,0 +1,113 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import time + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.gan import mnist + +NOISE_DIM = 100 +# Big enough so that summaries are never recorded. +# Lower this value if would like to benchmark with some summaries. +SUMMARY_INTERVAL = 10000 +SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms + + +def data_format(): + return 'channels_first' if tf.test.is_gpu_available() else 'channels_last' + + +def device(): + return '/gpu:0' if tfe.num_gpus() else '/cpu:0' + + +class MnistEagerGanBenchmark(tf.test.Benchmark): + + def _report(self, test_name, start, num_iters, batch_size): + avg_time = (time.time() - start) / num_iters + dev = 'gpu' if tfe.num_gpus() else 'cpu' + name = 'eager_%s_%s_batch_%d_%s' % (test_name, dev, batch_size, + data_format()) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def benchmark_train(self): + for batch_size in [64, 128, 256]: + # Generate some random data. + burn_batches, measure_batches = (3, 100) + burn_images = [tf.random_normal([batch_size, 784]) + for _ in range(burn_batches)] + burn_dataset = tf.data.Dataset.from_tensor_slices(burn_images) + measure_images = [tf.random_normal([batch_size, 784]) + for _ in range(measure_batches)] + measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images) + + tf.train.get_or_create_global_step() + with tf.device(device()): + # Create the models and optimizers + generator = mnist.Generator(data_format()) + discriminator = mnist.Discriminator(data_format()) + with tf.variable_scope('generator'): + generator_optimizer = tf.train.AdamOptimizer(0.001) + with tf.variable_scope('discriminator'): + discriminator_optimizer = tf.train.AdamOptimizer(0.001) + + with tf.contrib.summary.create_file_writer( + tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default(): + + # warm up + mnist.train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, + burn_dataset, log_interval=SUMMARY_INTERVAL, + noise_dim=NOISE_DIM) + # measure + start = time.time() + mnist.train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, + measure_dataset, log_interval=SUMMARY_INTERVAL, + noise_dim=NOISE_DIM) + self._report('train', start, measure_batches, batch_size) + + def benchmark_generate(self): + for batch_size in [64, 128, 256]: + with tf.device(device()): + # Using random weights. This will generate garbage. + generator = mnist.Generator(data_format()) + + num_burn, num_iters = (30, 1000) + for _ in range(num_burn): + noise = tf.random_uniform(shape=[batch_size, NOISE_DIM], + minval=-1., maxval=1.) + generator(noise) + + start = time.time() + for _ in range(num_iters): + noise = tf.random_uniform(shape=[batch_size, NOISE_DIM], + minval=-1., maxval=1.) + generator(noise) + self._report('generate', start, num_iters, batch_size) + + +if __name__ == '__main__': + tfe.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index d0130ebd118dbaff4f0161c8b2528764c6103e02..f4b7d67f940f5d752e1d22d643b763e2d97e987e 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -41,7 +41,7 @@ class LinearModel(tfe.Network): For those familiar with TensorFlow graphs, notice the absence of `tf.Session`. The `forward()` method here immediately executes and returns output values. The `loss()` method immediately compares the - output of `forward()` with the target adn returns the MSE loss value. + output of `forward()` with the target and returns the MSE loss value. The `fit()` performs gradient-descent training on the model's weights and bias. """ @@ -85,7 +85,7 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= - summary_writer = tf.contrib.summary.create_summary_file_writer(logdir) + summary_writer = tf.contrib.summary.create_file_writer(logdir) # Training loop. for i, (xs, ys) in enumerate(tfe.Iterator(dataset)): diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index bfb7d5a9002787f6544d383de58150661ac2bde3..82b3d3919cf0176961853d2bd85802e5dafa789e 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -40,7 +40,7 @@ class MNISTModel(tfe.Network): """MNIST Network. Network structure is equivalent to: - https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/examples/tutorials/mnist/mnist_deep.py + https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py and https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py @@ -190,9 +190,9 @@ def main(_): else: train_dir = None test_dir = None - summary_writer = tf.contrib.summary.create_summary_file_writer( + summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( + test_summary_writer = tf.contrib.summary.create_file_writer( test_dir, flush_millis=10000, name='test') checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py index 205709fe2edd3c260c30a84b624e322e120edf8e..136085eba21284a42282395e54f32c33bf63b5c3 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py @@ -39,22 +39,40 @@ def random_dataset(): return tf.data.Dataset.from_tensors((images, labels)) +def train_one_epoch(defun=False): + model = mnist.MNISTModel(data_format()) + if defun: + model.call = tfe.defun(model.call) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) + dataset = random_dataset() + with tf.device(device()): + tf.train.get_or_create_global_step() + mnist.train_one_epoch(model, optimizer, dataset) + + +def evaluate(defun=False): + model = mnist.MNISTModel(data_format()) + dataset = random_dataset() + if defun: + model.call = tfe.defun(model.call) + with tf.device(device()): + tf.train.get_or_create_global_step() + mnist.test(model, dataset) + + class MNISTTest(tf.test.TestCase): def testTrainOneEpoch(self): - model = mnist.MNISTModel(data_format()) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - dataset = random_dataset() - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.train_one_epoch(model, optimizer, dataset) + train_one_epoch(defun=False) def testTest(self): - model = mnist.MNISTModel(data_format()) - dataset = random_dataset() - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.test(model, dataset) + evaluate(defun=False) + + def testTrainOneEpochWithDefunCall(self): + train_one_epoch(defun=True) + + def testTestWithDefunCall(self): + evaluate(defun=True) if __name__ == "__main__": diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 14c82c87a72457d414c4a1d3c53d4d1a68a400e6..23317886e712323f4b520000e0fd372734fc53a1 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -73,7 +73,7 @@ class ResNet50GraphTest(tf.test.TestCase): tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() with tf.contrib.summary.always_record_summaries(): - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(): model = resnet50.ResNet50(data_format()) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 582f4837c6f3197081cb558063e963866d173f29..e2ae665a74fcf297b3174006783a7b8fed19ff03 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -64,14 +64,22 @@ def train_one_step(model, images, labels, optimizer): class ResNet50Test(tf.test.TestCase): - def test_apply(self): + def _apply(self, defun=False): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) + if defun: + model.call = tfe.defun(model.call) with tf.device(device): images, _ = random_batch(2) output = model(images) self.assertEqual((2, 1000), output.shape) + def test_apply(self): + self._apply(defun=False) + + def test_apply_with_defun(self): + self._apply(defun=True) + def test_apply_no_top(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) @@ -95,7 +103,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format) tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(), tf.contrib.summary.always_record_summaries(): with tf.device(device): @@ -175,9 +183,11 @@ class ResNet50Benchmarks(tf.test.Benchmark): # a sync. This is a roundabout way, yes. tf.constant(1.).cpu() - def benchmark_eager_apply(self): + def _benchmark_eager_apply(self, label, defun=False): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) + if defun: + model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 num_iters = 30 @@ -189,16 +199,23 @@ class ResNet50Benchmarks(tf.test.Benchmark): start = time.time() for _ in xrange(num_iters): model(images).cpu() - self._report('eager_apply', start, num_iters, device, batch_size, - data_format) + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_apply(self): + self._benchmark_eager_apply('eager_apply', defun=False) + + def benchmark_eager_apply_with_defun(self): + self._benchmark_eager_apply('eager_apply_with_defun', defun=True) - def _benchmark_eager_train(self, label, make_iterator): + def _benchmark_eager_train(self, label, make_iterator, defun=False): device, data_format = device_and_data_format() for batch_size in self._train_batch_sizes(): (images, labels) = random_batch(batch_size) num_burn = 3 num_iters = 10 model = resnet50.ResNet50(data_format) + if defun: + model.call = tfe.defun(model.call) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): @@ -217,7 +234,11 @@ class ResNet50Benchmarks(tf.test.Benchmark): self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_train(self): - self._benchmark_eager_train('eager_train', MockIterator) + self._benchmark_eager_train('eager_train', MockIterator, defun=False) + + def benchmark_eager_train_with_defun(self): + self._benchmark_eager_train( + 'eager_train_with_defun', MockIterator, defun=True) def benchmark_eager_train_datasets(self): @@ -226,7 +247,18 @@ class ResNet50Benchmarks(tf.test.Benchmark): ds = tf.data.Dataset.from_tensors(tensors).repeat() return tfe.Iterator(ds) - self._benchmark_eager_train('eager_train_dataset', make_iterator) + self._benchmark_eager_train( + 'eager_train_dataset', make_iterator, defun=False) + + def benchmark_eager_train_datasets_with_defun(self): + + def make_iterator(tensors): + with tf.device('/device:CPU:0'): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train( + 'eager_train_dataset_with_defun', make_iterator, defun=True) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 609cbd28772c3ae8da70648ca5b1b264a8a255e2..40919f2d4cf511eb35fac954719286366aef6c7c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -247,9 +247,9 @@ def main(_): log_dir = os.path.join(FLAGS.dir, "summaries") tf.gfile.MakeDirs(log_dir) - train_summary_writer = tf.contrib.summary.create_summary_file_writer( + train_summary_writer = tf.contrib.summary.create_file_writer( os.path.join(log_dir, "train"), flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( + test_summary_writer = tf.contrib.summary.create_file_writer( os.path.join(log_dir, "eval"), flush_millis=10000, name="eval") with tf.device(device): diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 30bb3c8ad33d38453bd96a76c7770071e24bb034..7b9637a9d58c87e93c7c0ea7173a6b88c885ee25 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -22,6 +22,11 @@ Usage: python ./rnn_ptb.py --data-path= Penn Treebank (PTB) dataset from: http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz """ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import argparse import os import sys @@ -209,7 +214,7 @@ class Datasets(object): """Load the Penn Treebank dataset. Args: - path: Path to the data/ directory of the dataset from from Tomas Mikolov's + path: Path to the data/ directory of the dataset from Tomas Mikolov's webpage - http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz """ diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a1f8a759e2a556bc219f0aa13942f293c4f34cfa --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -0,0 +1,42 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "data", + srcs = ["data.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = ["//third_party/py/numpy"], +) + +py_test( + name = "data_test", + size = "small", + srcs = ["data_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":data", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "spinn_test", + size = "medium", + srcs = ["spinn_test.py"], + additional_deps = [ + ":data", + "//third_party/examples/eager/spinn", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/python/eager:test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], + tags = ["no_pip"], # because spinn.py is under third_party/. +) diff --git a/tensorflow/contrib/eager/python/examples/spinn/README.md b/tensorflow/contrib/eager/python/examples/spinn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eb0637df473e22e5d39ca1b0816464cb2b7c6435 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/README.md @@ -0,0 +1,13 @@ +# SPINN: Dynamic neural network with TensorFlow eager execution + +This directory contains files supporting the +[spinn.py model in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/spinn.py), +including + +- `data.py`: Utility library for loading and preprocessing the SNLI and GloVe + data. +- `data_test.py` and `spinn_test.py`: Unit tests for the data and model modules. + +See the [README.md in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/README.md) +for detailed background, license and usage information regarding the SPINN code. + diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e046320f78541bef4e091e97f08fd51857af83 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -0,0 +1,350 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities of SNLI data and GloVe word vectors for SPINN model. + +See more details about the SNLI data set at: + https://nlp.stanford.edu/projects/snli/ + +See more details about the GloVe pretrained word embeddings at: + https://nlp.stanford.edu/projects/glove/ +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import math +import os +import random + +import numpy as np + +POSSIBLE_LABELS = ("entailment", "contradiction", "neutral") + +UNK_CODE = 0 # Code for unknown word tokens. +PAD_CODE = 1 # Code for padding tokens. + +SHIFT_CODE = 3 +REDUCE_CODE = 2 + +WORD_VECTOR_LEN = 300 # Embedding dimensions. + +LEFT_PAREN = "(" +RIGHT_PAREN = ")" +PARENTHESES = (LEFT_PAREN, RIGHT_PAREN) + + +def get_non_parenthesis_words(items): + """Get the non-parenthesis items from a SNLI parsed sentence. + + Args: + items: Data items from a parsed SNLI setence, with parentheses. E.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of non-parenthis word items, all converted to lower case. E.g., + ["man", "wearing", "pass", ... + """ + return [x.lower() for x in items if x not in PARENTHESES and x] + + +def get_shift_reduce(items): + """Obtain shift-reduce vector from a list of items from the SNLI data. + + Args: + items: Data items as a list of str, e.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and + `REDUCE_CODE` for reduce. See code above for the values of `SHIFT_CODE` + and `REDUCE_CODE`. + """ + trans = [] + for item in items: + if item == LEFT_PAREN: + continue + elif item == RIGHT_PAREN: + trans.append(REDUCE_CODE) + else: + trans.append(SHIFT_CODE) + return trans + + +def pad_and_reverse_word_ids(sentences): + """Pad a list of sentences to the common maximum length + 1. + + Args: + sentences: A list of sentences as a list of list of integers. Each integer + is a word ID. Each list of integer corresponds to one sentence. + + Returns: + A numpy.ndarray of shape (num_sentences, max_length + 1), wherein max_length + is the maximum sentence length (in # of words). Each sentence is reversed + and then padded with an extra one at head, as required by the model. + """ + max_len = max(len(sent) for sent in sentences) + for sent in sentences: + if len(sent) < max_len: + sent.extend([PAD_CODE] * (max_len - len(sent))) + # Reverse in time order and pad an extra one. + sentences = np.fliplr(np.array(sentences, dtype=np.int64)) + sentences = np.concatenate( + [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1) + return sentences + + +def pad_transitions(sentences_transitions): + """Pad a list of shift-reduce transitions to the maximum length.""" + max_len = max(len(transitions) for transitions in sentences_transitions) + for transitions in sentences_transitions: + if len(transitions) < max_len: + transitions.extend([PAD_CODE] * (max_len - len(transitions))) + return np.array(sentences_transitions, dtype=np.int64) + + +def load_vocabulary(data_root): + """Load vocabulary from SNLI data files. + + Args: + data_root: Root directory of the data. It is assumed that the SNLI data + files have been downloaded and extracted to the "snli/snli_1.0" + subdirectory of it. + + Returns: + Vocabulary as a set of strings. + + Raises: + ValueError: If SNLI data files cannot be found. + """ + snli_path = os.path.join(data_root, "snli") + snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt") + file_names = glob.glob(snli_glob_pattern) + if not file_names: + raise ValueError( + "Cannot find SNLI data files at %s. " + "Please download and extract SNLI data first." % snli_glob_pattern) + + print("Loading vocabulary...") + vocab = set() + for file_name in file_names: + with open(os.path.join(snli_path, file_name), "rt") as f: + for i, line in enumerate(f): + if i == 0: + continue + items = line.split("\t") + premise_words = get_non_parenthesis_words(items[1].split(" ")) + hypothesis_words = get_non_parenthesis_words(items[2].split(" ")) + vocab.update(premise_words) + vocab.update(hypothesis_words) + return vocab + + +def load_word_vectors(data_root, vocab): + """Load GloVe word vectors for words present in the vocabulary. + + Args: + data_root: Data root directory. It is assumed that the GloVe file + has been downloaded and extracted at the "glove/" subdirectory of it. + vocab: A `set` of words, representing the vocabulary. + + Returns: + 1. word2index: A dict from lower-case word to row index in the embedding + matrix, i.e, `embed` below. + 2. embed: The embedding matrix as a float32 numpy array. Its shape is + [vocabulary_size, WORD_VECTOR_LEN]. vocabulary_size is len(vocab). + WORD_VECTOR_LEN is the embedding dimension (300). + + Raises: + ValueError: If GloVe embedding file cannot be found. + """ + glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt") + if not os.path.isfile(glove_path): + raise ValueError( + "Cannot find GloVe embedding file at %s. " + "Please download and extract GloVe embeddings first." % glove_path) + + print("Loading word vectors...") + + word2index = dict() + embed = [] + + embed.append([0] * WORD_VECTOR_LEN) # + embed.append([0] * WORD_VECTOR_LEN) # + word2index[""] = UNK_CODE + word2index[""] = PAD_CODE + + with open(glove_path, "rt") as f: + for line in f: + items = line.split(" ") + word = items[0] + if word in vocab and word not in word2index: + word2index[word] = len(embed) + vector = np.array([float(item) for item in items[1:]]) + assert (WORD_VECTOR_LEN,) == vector.shape + embed.append(vector) + embed = np.array(embed, dtype=np.float32) + return word2index, embed + + +def calculate_bins(length2count, min_bin_size): + """Cacluate bin boundaries given a histogram of lengths and mininum bin size. + + Args: + length2count: A `dict` mapping length to sentence count. + min_bin_size: Minimum bin size in terms of total number of sentence pairs + in the bin. + + Returns: + A `list` representing the right bin boundaries, starting from the inclusive + right boundary of the first bin. For example, if the output is + [10, 20, 35], + it means there are three bins: [1, 10], [11, 20] and [21, 35]. + """ + bounds = [] + lengths = sorted(length2count.keys()) + cum_count = 0 + for length in lengths: + cum_count += length2count[length] + if cum_count >= min_bin_size: + bounds.append(length) + cum_count = 0 + if bounds[-1] != lengths[-1]: + bounds.append(lengths[-1]) + return bounds + + +class SnliData(object): + """A split of SNLI data.""" + + def __init__(self, data_file, word2index, sentence_len_limit=-1): + """SnliData constructor. + + Args: + data_file: Full path to the data file, e.g., + "/tmp/spinn-data/snli/snli_1.0/snli_1.0.train.txt" + word2index: A dict from lower-case word to row index in the embedding + matrix (see `load_word_vectors()` for details). + sentence_len_limit: Maximum allowed sentence length (# of words). + A value of <= 0 means unlimited. Sentences longer than this limit + are currently discarded, not truncated. + """ + + self._labels = [] + self._premises = [] + self._premise_transitions = [] + self._hypotheses = [] + self._hypothesis_transitions = [] + + with open(data_file, "rt") as f: + for i, line in enumerate(f): + if i == 0: + # Skip header line. + continue + items = line.split("\t") + if items[0] not in POSSIBLE_LABELS: + continue + + premise_items = items[1].split(" ") + hypothesis_items = items[2].split(" ") + premise_words = get_non_parenthesis_words(premise_items) + hypothesis_words = get_non_parenthesis_words(hypothesis_items) + + if (sentence_len_limit > 0 and + (len(premise_words) > sentence_len_limit or + len(hypothesis_words) > sentence_len_limit)): + # TODO(cais): Maybe truncate; do not discard. + continue + + premise_ids = [ + word2index.get(word, UNK_CODE) for word in premise_words] + hypothesis_ids = [ + word2index.get(word, UNK_CODE) for word in hypothesis_words] + + self._premises.append(premise_ids) + self._hypotheses.append(hypothesis_ids) + self._premise_transitions.append(get_shift_reduce(premise_items)) + self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items)) + assert (len(self._premise_transitions[-1]) == + 2 * len(premise_words) - 1) + assert (len(self._hypothesis_transitions[-1]) == + 2 * len(hypothesis_words) - 1) + + self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1) + + assert len(self._labels) == len(self._premises) + assert len(self._labels) == len(self._hypotheses) + assert len(self._labels) == len(self._premise_transitions) + assert len(self._labels) == len(self._hypothesis_transitions) + + def num_batches(self, batch_size): + """Calculate number of batches given batch size.""" + return int(math.ceil(len(self._labels) / batch_size)) + + def get_generator(self, batch_size): + """Obtain a generator for batched data. + + All examples of this SnliData object are randomly shuffled, sorted + according to the maximum sentence length of the premise and hypothesis + sentences in the pair, and batched. + + Args: + batch_size: Desired batch size. + + Returns: + A generator for data batches. The generator yields a 5-tuple: + label: An array of the shape (batch_size,). + premise: An array of the shape (max_premise_len, batch_size), wherein + max_premise_len is the maximum length of the (padded) premise + sentence in the batch. + premise_transitions: An array of the shape (2 * max_premise_len -3, + batch_size). + hypothesis: Same as `premise`, but for hypothesis sentences. + hypothesis_transitions: Same as `premise_transitions`, but for + hypothesis sentences. + All the elements of the 5-tuple have dtype `int64`. + """ + # Randomly shuffle examples. + zipped = list(zip( + self._labels, self._premises, self._premise_transitions, + self._hypotheses, self._hypothesis_transitions)) + random.shuffle(zipped) + # Then sort the examples by maximum of the premise and hypothesis sentence + # lengths in the pair. During training, the batches are expected to be + # shuffled. So it is okay to leave them sorted by max length here. + (labels, premises, premise_transitions, hypotheses, + hypothesis_transitions) = zip( + *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3])))) + + def _generator(): + begin = 0 + while begin < len(labels): + # The sorting above and the batching here makes sure that sentences of + # similar max lengths are batched together, minimizing the inefficiency + # due to uneven max lengths. The sentences are batched differently in + # each call to get_generator() due to the shuffling before sotring + # above. The pad_and_reverse_word_ids() and pad_transitions() functions + # take care of any remaning unevenness of the max sentence lengths. + end = min(begin + batch_size, len(labels)) + # Transpose, because the SPINN model requires time-major, instead of + # batch-major. + yield (labels[begin:end], + pad_and_reverse_word_ids(premises[begin:end]).T, + pad_transitions(premise_transitions[begin:end]).T, + pad_and_reverse_word_ids(hypotheses[begin:end]).T, + pad_transitions(hypothesis_transitions[begin:end]).T) + begin = end + return _generator diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f0b37c5099e45b7e3b258b258c0a203c36b3b7 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py @@ -0,0 +1,243 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for SPINN data module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.spinn import data + + +class DataTest(tf.test.TestCase): + + def setUp(self): + super(DataTest, self).setUp() + self._temp_data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(DataTest, self).tearDown() + + def testGenNonParenthesisWords(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") + self.assertEqual( + ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing", + "in", "a", "crowd", "of", "people", "."], + data.get_non_parenthesis_words(seq_with_parse.split(" "))) + + def testGetShiftReduce(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") + self.assertEqual( + [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2, + 3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" "))) + + def testPadAndReverseWordIds(self): + id_sequences = [[0, 2, 3, 4, 5], + [6, 7, 8], + [9, 10, 11, 12, 13, 14, 15, 16]] + self.assertAllClose( + [[1, 1, 1, 1, 5, 4, 3, 2, 0], + [1, 1, 1, 1, 1, 1, 8, 7, 6], + [1, 16, 15, 14, 13, 12, 11, 10, 9]], + data.pad_and_reverse_word_ids(id_sequences)) + + def testPadTransitions(self): + unpadded = [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2]] + self.assertAllClose( + [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2, 1, 1]], + data.pad_transitions(unpadded)) + + def testCalculateBins(self): + length2count = { + 1: 10, + 2: 15, + 3: 25, + 4: 40, + 5: 35, + 6: 10} + self.assertEqual([2, 3, 4, 5, 6], + data.calculate_bins(length2count, 20)) + self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40)) + self.assertEqual([4, 6], data.calculate_bins(length2count, 60)) + + def testLoadVoacbulary(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt") + os.makedirs(snli_1_0_dir) + + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + with open(fake_dev_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Quux quuz?\t.Corge grault!\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + self.assertSetEqual( + {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"}, + vocab) + + def testLoadVoacbularyWithoutFileRaisesError(self): + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + def testLoadWordVectors(self): + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", ",", "foo", "bar", "baz"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = {"foo", "bar", "baz", "qux", "."} + # Notice that "qux" is not present in `words`. + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + self.assertEqual(6, len(word2index)) + self.assertEqual(0, word2index[""]) + self.assertEqual(1, word2index[""]) + self.assertEqual(2, word2index["."]) + self.assertEqual(3, word2index["foo"]) + self.assertEqual(4, word2index["bar"]) + self.assertEqual(5, word2index["baz"]) + self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :]) + self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :]) + self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :]) + self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :]) + + def testLoadWordVectorsWithoutFileRaisesError(self): + vocab = {"foo", "bar", "baz", "qux", "."} + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + os.makedirs(os.path.join(self._temp_data_dir, "glove")) + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + def testSnliData(self): + """Unit test for SnliData objects.""" + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + + # Four sentences in total. + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + self.assertEqual(4, train_data.num_batches(1)) + self.assertEqual(2, train_data.num_batches(2)) + self.assertEqual(2, train_data.num_batches(3)) + self.assertEqual(1, train_data.num_batches(4)) + + generator = train_data.get_generator(2)() + for i in range(2): + label, prem, prem_trans, hypo, hypo_trans = next(generator) + self.assertEqual(2, len(label)) + self.assertEqual((4, 2), prem.shape) + self.assertEqual((5, 2), prem_trans.shape) + self.assertEqual((3, 2), hypo.shape) + self.assertEqual((3, 2), hypo_trans.shape) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..84e25cf81a2223800c47994b26d000caddee6b01 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -0,0 +1,409 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import gc +import glob +import os +import shutil +import tempfile +import time + +import numpy as np +import tensorflow as tf + +# pylint: disable=g-bad-import-order +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.spinn import data +from third_party.examples.eager.spinn import spinn +from tensorflow.contrib.summary import summary_test_util +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +# pylint: enable=g-bad-import-order + + +def _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size): + """Generate a fake batch of SNLI data for testing.""" + with tf.device("cpu:0"): + labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64) + prem = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) + prem_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + hypo = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) + hypo_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + if tfe.num_gpus(): + labels = labels.gpu() + prem = prem.gpu() + prem_trans = prem_trans.gpu() + hypo = hypo.gpu() + hypo_trans = hypo_trans.gpu() + return labels, prem, prem_trans, hypo, hypo_trans + + +def _test_spinn_config(d_embed, d_out, logdir=None): + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict", + "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp", + "d_out", "projection", "lr", "batch_size", "epochs", + "force_cpu", "logdir", "log_every", "dev_every", "save_every", + "lr_decay_every", "lr_decay_by"]) + return config_tuple( + d_hidden=d_embed, + d_proj=d_embed * 2, + d_tracker=8, + predict=False, + embed_dropout=0.1, + mlp_dropout=0.1, + n_mlp_layers=2, + d_mlp=32, + d_out=d_out, + projection=True, + lr=2e-2, + batch_size=2, + epochs=10, + force_cpu=False, + logdir=logdir, + log_every=1, + dev_every=2, + save_every=2, + lr_decay_every=1, + lr_decay_by=0.75) + + +class SpinnTest(test_util.TensorFlowTestCase): + + def setUp(self): + super(SpinnTest, self).setUp() + self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + self._temp_data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(SpinnTest, self).tearDown() + + def testBundle(self): + with tf.device(self._test_device): + lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32), + np.array([[0, -1], [-2, -3]], dtype=np.float32), + np.array([[0, 2], [4, 6]], dtype=np.float32), + np.array([[0, -2], [-4, -6]], dtype=np.float32)] + out = spinn._bundle(lstm_iter) + + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T, + out[0].numpy()) + self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T, + out[1].numpy()) + + def testUnbunbdle(self): + with tf.device(self._test_device): + state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32), + np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)] + out = spinn._unbundle(state) + + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]), + out[0].numpy()) + self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]), + out[1].numpy()) + + def testReducer(self): + with tf.device(self._test_device): + batch_size = 3 + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + left_in = [] + right_in = [] + tracking = [] + for _ in range(batch_size): + left_in.append(tf.random_normal((1, size * 2))) + right_in.append(tf.random_normal((1, size * 2))) + tracking.append(tf.random_normal((1, tracker_size * 2))) + + out = reducer(left_in, right_in, tracking=tracking) + self.assertEqual(batch_size, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual((1, size * 2), out[0].shape) + + def testReduceTreeLSTM(self): + with tf.device(self._test_device): + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + lstm_in = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [0, -1, -2, -3, -4, -5, -6, -7, -8, -9]], + dtype=np.float32) + c1 = np.array([[0, 1], [2, 3]], dtype=np.float32) + c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32) + + h, c = reducer._tree_lstm(c1, c2, lstm_in) + self.assertEqual(tf.float32, h.dtype) + self.assertEqual(tf.float32, c.dtype) + self.assertEqual((2, 2), h.shape) + self.assertEqual((2, 2), c.shape) + + def testTracker(self): + with tf.device(self._test_device): + batch_size = 2 + size = 10 + tracker_size = 8 + buffer_length = 18 + stack_size = 3 + + tracker = spinn.Tracker(tracker_size, False) + tracker.reset_state() + + # Create dummy inputs for testing. + bufs = [] + buf = [] + for _ in range(buffer_length): + buf.append(tf.random_normal((batch_size, size * 2))) + bufs.append(buf) + self.assertEqual(1, len(bufs)) + self.assertEqual(buffer_length, len(bufs[0])) + self.assertEqual((batch_size, size * 2), bufs[0][0].shape) + + stacks = [] + stack = [] + for _ in range(stack_size): + stack.append(tf.random_normal((batch_size, size * 2))) + stacks.append(stack) + self.assertEqual(1, len(stacks)) + self.assertEqual(3, len(stacks[0])) + self.assertEqual((batch_size, size * 2), stacks[0][0].shape) + + for _ in range(2): + out1, out2 = tracker(bufs, stacks) + self.assertIsNone(out2) + self.assertEqual(batch_size, len(out1)) + self.assertEqual(tf.float32, out1[0].dtype) + self.assertEqual((1, tracker_size * 2), out1[0].shape) + + self.assertEqual(tf.float32, tracker.state.c.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.c.shape) + self.assertEqual(tf.float32, tracker.state.h.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.h.shape) + + def testSPINN(self): + with tf.device(self._test_device): + embedding_dims = 10 + d_tracker = 8 + sequence_length = 15 + num_transitions = 27 + + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict"]) + config = config_tuple( + embedding_dims, embedding_dims * 2, d_tracker, False) + s = spinn.SPINN(config) + + # Create some fake data. + buffers = tf.random_normal((sequence_length, 1, config.d_proj)) + transitions = tf.constant( + [[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3], + [2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2], + [3], [2], [2]], dtype=tf.int64) + self.assertEqual(tf.int64, transitions.dtype) + self.assertEqual((num_transitions, 1), transitions.shape) + + out = s(buffers, transitions, training=True) + self.assertEqual(tf.float32, out.dtype) + self.assertEqual((1, embedding_dims), out.shape) + + def testSNLIClassifierAndTrainer(self): + with tf.device(self._test_device): + vocab_size = 40 + batch_size = 2 + d_embed = 10 + sequence_length = 15 + d_out = 4 + + config = _test_spinn_config(d_embed, d_out) + + # Create fake embedding matrix. + embed = tf.random_normal((vocab_size, d_embed)) + + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + # Invoke model under non-training mode. + logits = model(prem, prem_trans, hypo, hypo_trans, training=False) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Invoke model under training model. + logits = model(prem, prem_trans, hypo, hypo_trans, training=True) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Calculate loss. + loss1 = trainer.loss(labels, logits) + self.assertEqual(tf.float32, loss1.dtype) + self.assertEqual((), loss1.shape) + + loss2, logits = trainer.train_batch( + labels, prem, prem_trans, hypo, hypo_trans) + self.assertEqual(tf.float32, loss2.dtype) + self.assertEqual((), loss2.shape) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + # Training on the batch should have led to a change in the loss value. + self.assertNotEqual(loss1.numpy(), loss2.numpy()) + + def testTrainSpinn(self): + """Test with fake toy SNLI data and GloVe vectors.""" + + # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + + # Four sentences in total. + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + dev_data = data.SnliData(fake_train_file, word2index) + test_data = data.SnliData(fake_train_file, word2index) + print(embed) + + # 2. Create a fake config. + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir")) + + # 3. Test training of a SPINN model. + spinn.train_spinn(embed, train_data, dev_data, test_data, config) + + # 4. Load train loss values from the summary files and verify that they + # decrease with training. + summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0] + events = summary_test_util.events_from_file(summary_file) + train_losses = [event.summary.value[0].simple_value for event in events + if event.summary.value + and event.summary.value[0].tag == "train/loss"] + self.assertEqual(config.epochs, len(train_losses)) + self.assertLess(train_losses[-1], train_losses[0]) + + +class EagerSpinnSNLIClassifierBenchmark(test.Benchmark): + + def benchmarkEagerSpinnSNLIClassifier(self): + test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + with tf.device(test_device): + burn_in_iterations = 2 + benchmark_iterations = 10 + + vocab_size = 1000 + batch_size = 128 + sequence_length = 15 + d_embed = 200 + d_out = 4 + + embed = tf.random_normal((vocab_size, d_embed)) + + config = _test_spinn_config(d_embed, d_out) + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + for _ in range(burn_in_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + + gc.collect() + start_time = time.time() + for _ in xrange(benchmark_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + wall_time = time.time() - start_time + # Named "examples"_per_sec to conform with other benchmarks. + extras = {"examples_per_sec": benchmark_iterations / wall_time} + self.report_benchmark( + name="Eager_SPINN_SNLIClassifier_Benchmark", + iters=benchmark_iterations, + wall_time=wall_time, + extras=extras) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index 147b7047f42b7ccba5829b61370e82e217ce5838..0095ffa0db99d46d25654d73504d0d7d41c18b6f 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -757,7 +757,7 @@ For example, to record summaries once every 100 global steps, use: ```python tf.train.get_or_create_global_step() # Ensuring the global step variable exists -writer = tf.contrib.summary.create_summary_file_writer(logdir) +writer = tf.contrib.summary.create_file_writer(logdir) for _ in range(iterations): with writer.as_default(): diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 2f8016ede3caee6dbb6fd8f5226f1464b5c3976b..bf029ca5f9dddb152274da6a1cc96bea7981d8fd 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -49,6 +49,20 @@ class Metric(object): Example use with graph execution: + ```python + m = SomeMetric(...) + inputs = ... # Some tensors to compute the metric on. + m_update = m(inputs) + # Variables defined in first call, so get the initialization op afterwards. + m_init = m.init_variables() # or tf.global_variables_initializer() + m_result = m.result() + with tf.Session() as sess: + sess.run(m_init) + for input in ...: + sess.run(m_update) + print(sess.run(m_result)) + ``` + Example use with graph execution with placeholders and feed_dict: ```python m = SomeMetric(...) m_placeholder = tf.placeholder(...) @@ -107,6 +121,7 @@ class Metric(object): """Returns op to execute to update this metric for these inputs. Returns None if eager execution is enabled. + Returns a graph-mode function if graph execution is enabled. Args: *args: @@ -183,6 +198,13 @@ class Metric(object): """Computes and returns a final value for the metric.""" raise NotImplementedError("Metrics must define a result() member function") + def value(self): + """In graph mode returns the result Tensor while in eager the callable.""" + if context.in_graph_mode(): + return self.result() + else: + return self.result + # We can support two different strategies of for doing data-parallel # distributed metric computations: # * Put metric variables on the first device and rely on small diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 96eb1b4f2a0e4c4af1f3310a2801b1b6aee285d6..9cf34fd9b2dcf1b123cacc6863af817419eda007 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -27,6 +27,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.training import training_util @@ -67,7 +68,7 @@ class MetricsTest(test.TestCase): m([1, 10, 100]) training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( logdir, max_queue=0, name="t0").as_default(), summary_ops.always_record_summaries(): m.result() # As a side-effect will write summaries. @@ -137,7 +138,7 @@ class MetricsTest(test.TestCase): self.assertEqual(m1.name, "has space") self.assertEqual(m1.numer.name, "has_space/numer:0") - def testGraph(self): + def testGraphWithPlaceholder(self): with context.graph_mode(), self.test_session() as sess: m = metrics.Mean() p = array_ops.placeholder(dtypes.float32) @@ -153,6 +154,22 @@ class MetricsTest(test.TestCase): sess.run(accumulate, feed_dict={p: 7}) self.assertAllEqual(m.result().eval(), 7) + @test_util.run_in_graph_and_eager_modes() + def testGraphAndEagerTensor(self): + m = metrics.Mean() + inputs = ops.convert_to_tensor([1.0, 2.0]) + accumulate = m(inputs) + result = m.result() + self.evaluate(m.init_variables()) + self.evaluate(accumulate) + self.assertEqual(self.evaluate(result), 1.5) + # Second init resets all the variables. + self.evaluate(m.init_variables()) + inputs = ops.convert_to_tensor([2.0, 3.0]) + self.evaluate(m(inputs)) + value = m.value() + self.assertEqual(self.evaluate(value), 2.5) + def testTwoMeansGraph(self): # Verify two metrics with the same class and name don't # accidentally share state. diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 0388aaa8495f380595b2635529bc2e33e808b06f..e3c13cbd2e8ccd2ab79da74e0e97905c6ed5c02d 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -451,8 +451,30 @@ class Network(base.Layer): "at https://github.com/tensorflow/tensorflow/issues/new if this is " "important to you") + def add_loss(self, losses, inputs=None): + raise RuntimeError( + "add_loss is not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") + + @property + def losses(self): + """Gather losses from `Layer`s in the `Network`. + + Note that when executing eagerly, `Layer.losses` evaluates + regularizers. When using graph execution, variable regularization ops have + already been created and are simply returned here. + + Returns: + A list of tensors. + """ + layer_losses = [] + for layer in self.layers: + layer_losses.extend(layer.losses) + return layer_losses + # TODO(allenl): Support other Layer methods needed for graph mode, such as for - # losses and updates + # updates class Sequential(Network): diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index e7835a63e6db926aa2d4b6c76c681c8a301757bd..8e6b947e5cb28910bcb4877aa66150992a8d6445 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import gc from tensorflow.contrib.eager.python import network +from tensorflow.contrib.layers.python.layers import regularizers from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test @@ -45,6 +46,22 @@ class MyNetwork(network.Network): return self.l1(x) +class RegularizedNetwork(network.Network): + + def __init__(self): + super(RegularizedNetwork, self).__init__() + self.l1 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0), + kernel_regularizer=regularizers.l1_regularizer(2.0))) + self.l2 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0))) + + def call(self, values): + return self.l2(self.l1(values)) + + class NetworkTest(test.TestCase): def _save_modify_load_network_built(self, net, global_step=None): @@ -88,15 +105,13 @@ class NetworkTest(test.TestCase): result = net(constant_op.constant([[2.0]])) self.assertEqual(34.0, self.evaluate(result)) - # TODO(akshayka): This test should be changed once an API for compiling - # `call` into a defun is implemented. def testReplacingNetworkCallWithDefun(self): net = MyNetwork(name="abcd") + net.call = function.defun(net.call) x = constant_op.constant([[2.0]]) net(x) # Force variables to be created. self.evaluate(net.trainable_variables[0].assign([[17.0]])) - net.call = function.defun(net.call) result = net(x) # Build and execute the TensorFlow function self.assertEqual(34.0, self.evaluate(result)) @@ -484,6 +499,18 @@ class NetworkTest(test.TestCase): _check_op_prefixes(expected_prefix="my_network_1/dense/", checked_ops=checked_ops) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableRegularizers(self): + net = RegularizedNetwork() + net(constant_op.constant([[1.]])) + self.evaluate(net.variables[0].assign([[2.]])) + self.evaluate(net.variables[1].assign([3.])) + self.evaluate(net.variables[2].assign([[-2.]])) + self.evaluate(net.variables[3].assign([4.])) + self.assertAllEqual([4., 6., 8.], self.evaluate(net.losses)) + self.evaluate(net.variables[3].assign([5.])) + self.assertAllEqual([4., 6., 10.], self.evaluate(net.losses)) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) diff --git a/tensorflow/contrib/eager/python/summary_writer.py b/tensorflow/contrib/eager/python/summary_writer.py deleted file mode 100644 index 5d8c41b545b3c9fd03af85f302ba05a394f085a4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TensorBoard Summary Writer for TensorFlow Eager Execution.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import uuid - -from tensorflow.contrib.summary import gen_summary_ops -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_op_util -from tensorflow.python.ops import variable_scope - - -def _maybe_cpu(v): - if isinstance(v, (ops.EagerTensor, ops.Tensor)): - return v.cpu() - else: - return v - - -def _summary_writer_function(name, tensor, function, family=None): - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - function(tag, scope) - return True - return record - - -class SummaryWriter(object): - """Writes summaries for TensorBoard, compatible with eager execution. - - This class is the supported way of writing TensorBoard summaries under - eager execution. - """ - - _CPU_DEVICE = "cpu:0" - - def __init__(self, - logdir, - max_queue=10, - flush_secs=120, - filename_suffix=""): - """Summary writer for TensorBoard, compatible with eager execution. - - If necessary, multiple instances of `SummaryWriter` can be created, with - distinct `logdir`s and `name`s. Each `SummaryWriter` instance will retain - its independent `global_step` counter and data writing destination. - - Example: - ```python - writer = tfe.SummaryWriter("my_model") - - # ... Code that sets up the model and data batches ... - - for _ in xrange(train_iters): - loss = model.train_batch(batch) - writer.scalar("loss", loss) - writer.step() - ``` - - Args: - logdir: Directory in which summary files will be written. - max_queue: Number of summary items to buffer before flushing to - filesystem. If 0, summaries will be flushed immediately. - flush_secs: Number of secondsbetween forced commits to disk. - filename_suffix: Suffix of the event protobuf files in which the summary - data are stored. - - Raises: - ValueError: If this constructor is called not under eager execution. - """ - # TODO(apassos, ashankar): Make this class and the underlying - # contrib.summary_ops compatible with graph model and remove this check. - if not context.in_eager_mode(): - raise ValueError( - "Use of SummaryWriter is currently supported only with eager " - "execution enabled. File an issue at " - "https://github.com/tensorflow/tensorflow/issues/new to express " - "interest in fixing this.") - - # TODO(cais): Consider adding name keyword argument, which if None or empty, - # will register the global global_step that training_util.get_global_step() - # can find. - with context.device(self._CPU_DEVICE): - self._name = uuid.uuid4().hex - self._global_step = 0 - self._global_step_tensor = variable_scope.get_variable( - "global_step/summary_writer/" + self._name, - shape=[], dtype=dtypes.int64, - initializer=init_ops.zeros_initializer()) - self._global_step_dirty = False - self._resource = gen_summary_ops.summary_writer(shared_name=self._name) - gen_summary_ops.create_summary_file_writer( - self._resource, logdir, max_queue, flush_secs, filename_suffix) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device=self._CPU_DEVICE) - - def step(self): - """Increment the global step counter of this SummaryWriter instance.""" - self._global_step += 1 - self._global_step_dirty = True - - @property - def global_step(self): - """Obtain the current global_step value of this SummaryWriter instance. - - Returns: - An `int` representing the current value of the global_step of this - `SummaryWriter` instance. - """ - return self._global_step - - def _update_global_step_tensor(self): - with context.device(self._CPU_DEVICE): - if self._global_step_dirty: - self._global_step_dirty = False - return state_ops.assign(self._global_step_tensor, self._global_step) - else: - return self._global_step_tensor - - def generic(self, name, tensor, metadata, family=None): - """Write a generic-type summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A `Tensor` or compatible value type containing the value of the - summary. - metadata: Metadata about the summary. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_summary( - self._resource, - self._update_global_step_tensor(), - _maybe_cpu(tensor), - tag, - _maybe_cpu(metadata), - name=scope) - - def scalar(self, name, tensor, family=None): - """Write a scalar summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type containing a - single value. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - - Returns: - A summary writer function for scalars. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_scalar_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def histogram(self, name, tensor, family=None): - """Write a histogram summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type. Any shape. - Values to use to build the histogram. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_histogram_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def image(self, name, tensor, bad_color=None, max_images=3, family=None): - """Write an image summary.""" - with context.device(self._CPU_DEVICE): - if bad_color is None: - bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_image_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), bad_color_, max_images, - name=scope) - - def audio(self, name, tensor, sample_rate, max_outputs, family=None): - """Write an audio summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` - or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`, or - compatible value type. - sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the - signal in hertz. - max_outputs: Max number of batch elements to generate audio for. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_audio_summary( - self._resource, self._update_global_step_tensor(), - tag, - _maybe_cpu(tensor), - sample_rate=_maybe_cpu(sample_rate), - max_outputs=max_outputs, - name=scope) diff --git a/tensorflow/contrib/eager/python/summary_writer_test.py b/tensorflow/contrib/eager/python/summary_writer_test.py deleted file mode 100644 index 5ebb36d04fcba8f4558fa1c09716314af42f559f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer_test.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit tests for eager execution SummaryWriter.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil -import tempfile - -import numpy as np - -from tensorflow.contrib.eager.python import summary_writer -from tensorflow.core.util import event_pb2 -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.lib.io import tf_record -from tensorflow.python.platform import gfile - - -class SummaryWriterTest(test.TestCase): - - def setUp(self): - super(SummaryWriterTest, self).setUp() - self._test_device = "gpu:0" if context.num_gpus() else "cpu:0" - self._tmp_logdir = tempfile.mkdtemp() - with context.device(self._test_device): - # Use max_queue=0 so that summaries are immediately flushed to filesystem, - # making testing easier. - self._writer = summary_writer.SummaryWriter(self._tmp_logdir, max_queue=0) - - def tearDown(self): - if os.path.isdir(self._tmp_logdir): - shutil.rmtree(self._tmp_logdir) - super(SummaryWriterTest, self).tearDown() - - def _readLastEvent(self, logdir=None): - if not logdir: - logdir = self._tmp_logdir - files = [f for f in gfile.ListDirectory(logdir) - if not gfile.IsDirectory(os.path.join(logdir, f))] - file_path = os.path.join(logdir, files[0]) - records = list(tf_record.tf_record_iterator(file_path)) - event = event_pb2.Event() - event.ParseFromString(records[-1]) - return event - - def testGlobalStep(self): - with context.device(self._test_device): - orig_step = self._writer.global_step - self._writer.step() - self.assertEqual(orig_step + 1, self._writer.global_step) - self.assertEqual(orig_step + 1, self._writer.global_step) - self._writer.step() - self._writer.step() - self.assertEqual(orig_step + 3, self._writer.global_step) - - def testGenericSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - with context.device("cpu:0"): - metadata = constant_op.constant("foo") - self._writer.generic("x", x, metadata) - event = self._readLastEvent() - self.assertEqual("x", event.summary.value[0].tag) - - def testScalarSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - self._writer.scalar("x", x) - event = self._readLastEvent() - self.assertTrue("x", event.summary.value[0].tag) - self.assertEqual(1337.0, event.summary.value[0].simple_value) - - def testHistogramSummary(self): - with context.device(self._test_device): - y = constant_op.constant([1.0, 3.0, 3.0, 7.0]) - self._writer.histogram("y", y) - event = self._readLastEvent() - self.assertEqual("y", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].histo) - - def testImageSummary(self): - with context.device(self._test_device): - a = constant_op.constant([[10.0, 20.0], [-20.0, -10.0]]) - self._writer.histogram("image1", a) - event = self._readLastEvent() - self.assertEqual("image1", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].image) - - def testAudioSummary(self): - with context.device(self._test_device): - w = constant_op.constant(np.random.rand(3, 10, 2), dtype=dtypes.float32) - fs = constant_op.constant(44100.0, dtype=dtypes.float32) - max_outputs = 1 - self._writer.audio("audio1", w, fs, max_outputs) - event = self._readLastEvent() - self.assertTrue(event.summary.value[0].audio) - - def testTwoSummaryWritersGlobalStepsWorkWithoutCrosstalk(self): - tmp_logdir2 = os.path.join(self._tmp_logdir, "_writer2_") - writer2 = summary_writer.SummaryWriter(tmp_logdir2, max_queue=0) - - self.assertEqual(0, writer2.global_step) - self._writer.step() - self.assertEqual(0, writer2.global_step) - writer2.step() - writer2.step() - writer2.step() - self.assertEqual(3, writer2.global_step) - - x = constant_op.constant(1337.0) - writer_orig_step = self._writer.global_step - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 1, event.step) - - writer2.scalar("x", x) - event = self._readLastEvent(tmp_logdir2) - self.assertEqual(3, event.step) - - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 2, event.step) - - -# TODO(cais): Add performance benchmark for SummaryWriter. - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 1697c879def8af5c05f3c9b11d318d570785d6de..712d1cb94d2f565bf6216f6c07a45d3d855efe9c 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -23,7 +23,9 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@list_devices @@num_gpus +@@py_func @@defun +@@make_template @@implicit_gradients @@implicit_value_and_gradients @@gradients_function @@ -50,6 +52,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@EagerVariableStore @@Network +@@Sequential @@save_network_checkpoint @@restore_network_checkpoint @@ -74,6 +77,7 @@ from __future__ import print_function from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.eager.python.datasets import Iterator from tensorflow.contrib.eager.python.network import Network +from tensorflow.contrib.eager.python.network import Sequential from tensorflow.contrib.eager.python.network import save_network_checkpoint from tensorflow.contrib.eager.python.network import restore_network_checkpoint from tensorflow.contrib.eager.python.saver import get_optimizer_variables @@ -101,9 +105,13 @@ from tensorflow.python.framework.test_util import IsolateTest from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore +from tensorflow.python.ops import script_ops +from tensorflow.python.ops import template from tensorflow.python.util.all_util import remove_undocumented +py_func = script_ops.eager_py_func defun = function.defun +make_template = template.make_template_internal implicit_gradients = backprop.implicit_grad implicit_value_and_gradients = backprop.implicit_val_and_grad gradients_function = backprop.gradients_function diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 8395e2db5ec0ce6f4adae5fa2467159549e70143..cdbe05e4d2d7117c5acb12d679f359a9db17c9cc 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -88,8 +88,9 @@ py_library( py_test( name = "dnn_linear_combined_test", - size = "small", + size = "medium", srcs = ["python/estimator/dnn_linear_combined_test.py"], + shard_count = 3, srcs_version = "PY2AND3", tags = [ "no_pip", @@ -204,6 +205,7 @@ py_test( "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", @@ -330,23 +332,24 @@ py_library( "//tensorflow/python:device", "//tensorflow/python:device_lib", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:util", + "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) cuda_py_test( name = "replicate_model_fn_test", - size = "small", + size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ "//tensorflow/python/estimator", @@ -374,5 +377,9 @@ cuda_py_test( "//tensorflow/python:variables", ":replicate_model_fn", ], - tags = ["multi_gpu"], + tags = [ + "manual", + "multi_gpu", + "notap", + ], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 8191e06faed004df6927708ea04a67b90bd464de..0f75b77050b0ba4c752a6a74fdc7024170b6f318 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -26,6 +26,7 @@ from tensorflow.contrib.estimator.python.estimator.head import * from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * +from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -45,6 +46,8 @@ _allowed_symbols = [ 'call_logit_fn', 'dnn_logit_fn_builder', 'linear_logit_fn_builder', + 'replicate_model_fn', + 'TowerOptimizer', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 29c3c7358534f6e8ebbd31cbfcd7e34086d9b506..c99bf8badb35e6fffb7cae8761db9d402b8b3a8f 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -100,7 +100,7 @@ def add_metrics(estimator, metric_fn): def clip_gradients_by_norm(optimizer, clip_norm): - """Returns an optimizer which clips gradients before appliying them. + """Returns an optimizer which clips gradients before applying them. Example: diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index a9311a20f127d92f02a95b8b48082fc90850635a..d6ca33e18923a5dd996431b0ff87c6ad3bccea92 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -44,6 +44,7 @@ _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY def multi_class_head(n_classes, weight_column=None, label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, name=None): """Creates a `_Head` for multi class classification. @@ -76,6 +77,8 @@ def multi_class_head(n_classes, integer within [0, n_classes). If given, labels must be of string type and have any value in `label_vocabulary`. Note that errors will be raised if `label_vocabulary` is not provided but labels are strings. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch. Defaults to `SUM`. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -83,17 +86,20 @@ def multi_class_head(n_classes, An instance of `_Head` for multi class classification. Raises: - ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid. + ValueError: if `n_classes`, `label_vocabulary` or `loss_reduction` is + invalid. """ return head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access n_classes=n_classes, weight_column=weight_column, label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction, name=name) def binary_classification_head( - weight_column=None, thresholds=None, label_vocabulary=None, name=None): + weight_column=None, thresholds=None, label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, name=None): """Creates a `_Head` for single label binary classification. This head uses `sigmoid_cross_entropy_with_logits` loss. @@ -128,6 +134,8 @@ def binary_classification_head( [0, 1]. If given, labels must be string type and have any value in `label_vocabulary`. Note that errors will be raised if `label_vocabulary` is not provided but labels are strings. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch. Defaults to `SUM`. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -135,17 +143,20 @@ def binary_classification_head( An instance of `_Head` for binary classification. Raises: - ValueError: if `thresholds` contains a value outside of `(0, 1)`. + ValueError: If `thresholds` contains a value outside of `(0, 1)`. + ValueError: If `loss_reduction` is invalid. """ return head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint:disable=protected-access weight_column=weight_column, thresholds=thresholds, label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction, name=name) def regression_head(weight_column=None, label_dimension=1, + loss_reduction=losses.Reduction.SUM, name=None): """Creates a `_Head` for regression using the `mean_squared_error` loss. @@ -172,15 +183,21 @@ def regression_head(weight_column=None, label_dimension: Number of regression labels per example. This is the size of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch. Defaults to `SUM`. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for linear regression. + + Raises: + ValueError: If `label_dimension` or `loss_reduction` is invalid. """ return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access weight_column=weight_column, label_dimension=label_dimension, + loss_reduction=loss_reduction, name=name) @@ -188,6 +205,7 @@ def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, loss_fn=None, name=None): """Creates a `_Head` for multi-label classification. @@ -237,6 +255,8 @@ def multi_label_head(n_classes, [0, n_classes) or multi-hot Tensor. If given, labels must be SparseTensor string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch. Defaults to `SUM`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -245,7 +265,8 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes`, `thresholds`, or `loss_fn` is invalid. + ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is + invalid. """ thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -267,9 +288,13 @@ def multi_label_head(n_classes, 'Given: {}'.format(n_classes, len(label_vocabulary))) if loss_fn: _validate_loss_fn_args(loss_fn) + if (loss_reduction not in losses.Reduction.all() or + loss_reduction == losses.Reduction.NONE): + raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, - label_vocabulary=label_vocabulary, loss_fn=loss_fn, name=name) + label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -280,12 +305,14 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weight_column=None, thresholds=None, label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, loss_fn=None, name=None): self._n_classes = n_classes self._weight_column = weight_column self._thresholds = thresholds self._label_vocabulary = label_vocabulary + self._loss_reduction = loss_reduction self._loss_fn = loss_fn self._name = name @@ -356,14 +383,12 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access unweighted_loss, axis=-1, keep_dims=True) weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, features=features, weight_column=self._weight_column, logits=logits) - weighted_sum_loss = losses.compute_weighted_loss( - unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) - # _weights() can return 1. - example_weight_sum = math_ops.reduce_sum( - weights * array_ops.ones_like(unweighted_loss)) + training_loss = losses.compute_weighted_loss( + unweighted_loss, weights=weights, reduction=self._loss_reduction) return head_lib.LossSpec( - weighted_sum_loss=weighted_sum_loss, - example_weight_sum=example_weight_sum, + training_loss=training_loss, + unreduced_loss=unweighted_loss, + weights=weights, processed_labels=processed_labels) def create_estimator_spec( @@ -394,60 +419,60 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access export_output.PredictOutput(predictions)) }) - (weighted_sum_loss, example_weight_sum, + (training_loss, unreduced_loss, weights, processed_labels) = self.create_loss( features=features, mode=mode, logits=logits, labels=labels) # Eval. if mode == model_fn.ModeKeys.EVAL: - weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, - features=features, weight_column=self._weight_column, logits=logits) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, - loss=weighted_sum_loss, + loss=training_loss, eval_metric_ops=self._eval_metric_ops( labels=processed_labels, probabilities=probabilities, weights=weights, - weighted_sum_loss=weighted_sum_loss, - example_weight_sum=example_weight_sum)) + unreduced_loss=unreduced_loss)) # Train. if train_op_fn is None: raise ValueError('train_op_fn can not be None.') + # Only summarize mean_loss for SUM reduction to preserve backwards + # compatibility. Otherwise skip it to avoid unnecessary computation. + if self._loss_reduction == losses.Reduction.SUM: + example_weight_sum = math_ops.reduce_sum( + weights * array_ops.ones_like(unreduced_loss)) + mean_loss = training_loss / example_weight_sum + else: + mean_loss = None with ops.name_scope(''): summary.scalar( head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS), # pylint:disable=protected-access - weighted_sum_loss) - summary.scalar( - head_lib._summary_key( # pylint:disable=protected-access - self._name, metric_keys.MetricKeys.LOSS_MEAN), - weighted_sum_loss / example_weight_sum) + training_loss) + if mean_loss is not None: + summary.scalar( + head_lib._summary_key( # pylint:disable=protected-access + self._name, metric_keys.MetricKeys.LOSS_MEAN), + mean_loss) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, predictions=predictions, - loss=weighted_sum_loss, - train_op=train_op_fn(weighted_sum_loss)) + loss=training_loss, + train_op=train_op_fn(training_loss)) - def _eval_metric_ops(self, labels, probabilities, weights, weighted_sum_loss, - example_weight_sum): + def _eval_metric_ops(self, labels, probabilities, weights, unreduced_loss): """Returns a dict of metrics for eval_metric_ops.""" with ops.name_scope( None, 'metrics', - [labels, probabilities, weights, weighted_sum_loss, example_weight_sum - ]): + [labels, probabilities, weights, unreduced_loss]): keys = metric_keys.MetricKeys metric_ops = { # Estimator already adds a metric for loss. head_lib._summary_key(self._name, keys.LOSS_MEAN): # pylint:disable=protected-access metrics_lib.mean( - # Both values and weights here are reduced, scalar Tensors. - # values is the actual mean we want, but we pass the scalar - # example_weight_sum in order to return the correct update_op - # alongside the value_op for streaming metrics. - values=(weighted_sum_loss / example_weight_sum), - weights=example_weight_sum, + values=unreduced_loss, + weights=weights, name=keys.LOSS_MEAN), head_lib._summary_key(self._name, keys.AUC): # pylint:disable=protected-access metrics_lib.auc(labels=labels, predictions=probabilities, diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index d1cf9090048470181818c573647923c9f5824dfa..e39e44541d2d30b1ecc9d4d41d0760decdc58168 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import monitored_session @@ -132,6 +133,16 @@ class MultiLabelHead(test.TestCase): r'Length of label_vocabulary must be n_classes \(3\). Given: 2'): head_lib.multi_label_head(n_classes=3, label_vocabulary=['foo', 'bar']) + def test_invalid_loss_reduction(self): + with self.assertRaisesRegexp( + ValueError, r'Invalid loss_reduction: invalid_loss_reduction'): + head_lib.multi_label_head( + n_classes=3, loss_reduction='invalid_loss_reduction') + with self.assertRaisesRegexp( + ValueError, r'Invalid loss_reduction: none'): + head_lib.multi_label_head( + n_classes=3, loss_reduction=losses.Reduction.NONE) + def test_loss_fn_arg_labels_missing(self): def _loss_fn(logits): del logits # Unused @@ -262,17 +273,17 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - expected_weighted_sum_loss = np.sum( + expected_training_loss = np.sum( _sigmoid_cross_entropy(labels=labels, logits=logits)) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(expected_weighted_sum_loss, - actual_weighted_sum_loss.eval()) + self.assertAllClose(expected_training_loss, + actual_training_loss.eval()) def test_eval_create_loss_large_logits(self): """Tests head.create_loss for eval mode and large logits.""" @@ -286,9 +297,9 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_weighted_sum_loss = np.sum( + expected_training_loss = np.sum( np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32)) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, @@ -296,9 +307,7 @@ class MultiLabelHead(test.TestCase): with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_weighted_sum_loss, - actual_weighted_sum_loss.eval(), - atol=1e-4) + expected_training_loss, actual_training_loss.eval(), atol=1e-4) def test_eval_create_loss_labels_wrong_shape(self): """Tests head.create_loss for eval mode when labels has the wrong shape.""" @@ -307,7 +316,7 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) labels_placeholder = array_ops.placeholder(dtype=dtypes.int64) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, @@ -317,14 +326,14 @@ class MultiLabelHead(test.TestCase): with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[expected_labels_shape: \] \[2 2\] \[labels_shape: \] \[2 1\]'): - actual_weighted_sum_loss.eval({ + actual_training_loss.eval({ labels_placeholder: np.array([[1], [1]], dtype=np.int64) }) with self.assertRaisesRegexp( errors.InvalidArgumentError, r'labels shape must be \[D0, D1, ... DN, 2\]\..*' r'\[Received shape: \] \[2\]'): - actual_weighted_sum_loss.eval({ + actual_training_loss.eval({ labels_placeholder: np.array([1, 1], dtype=np.int64) }) @@ -344,14 +353,14 @@ class MultiLabelHead(test.TestCase): return constant_op.constant(loss) head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits_input, labels=labels_input)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(np.sum(loss), actual_weighted_sum_loss.eval()) + self.assertAllClose(np.sum(loss), actual_training_loss.eval()) def test_eval_create_loss_loss_fn_wrong_shape(self): """Tests custom loss_fn that returns Tensor of unexpected shape.""" @@ -363,7 +372,7 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, @@ -374,7 +383,7 @@ class MultiLabelHead(test.TestCase): errors.InvalidArgumentError, r'loss_fn must return Tensor of shape \[batch_size, 1\]\. ' r'Given: \] \[2\]'): - actual_weighted_sum_loss.eval() + actual_training_loss.eval() def test_eval_labels_none(self): """Tests that error is raised when labels is None.""" @@ -618,12 +627,44 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_weighted_sum_loss = np.sum( - np.array( - [[1. * (10. + 10.) / 2.], [2. * (15. + 0.) / 2.]], - dtype=np.float32)) - expected_example_weight_sum = 1. + 2. - actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( + expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] + expected_weights = [[1.], [2.]] + expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2. + training_loss, unreduced_loss, actual_weights, _ = head.create_loss( + features={ + 'x': np.array(((42,),), dtype=np.int32), + 'example_weights': weights + }, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose( + expected_training_loss, training_loss.eval(), atol=1e-4) + self.assertAllClose( + expected_unreduced_loss, unreduced_loss.eval(), atol=1e-4) + self.assertAllClose(expected_weights, actual_weights.eval()) + + def test_train_create_loss_loss_reduction(self): + """Tests head.create_loss with loss_reduction.""" + n_classes = 2 + head = head_lib.multi_label_head( + n_classes, weight_column='example_weights', + loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS) + + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + weights = np.array([[1.], [2.]], dtype=np.float32) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # For large logits, this is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits + expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] + expected_weights = [[1.], [2.]] + expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2. + training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), 'example_weights': weights @@ -634,13 +675,10 @@ class MultiLabelHead(test.TestCase): with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_weighted_sum_loss, - actual_weighted_sum_loss.eval(), - atol=1e-4) + expected_training_loss, training_loss.eval(), atol=1e-4) self.assertAllClose( - expected_example_weight_sum, - actual_example_weight_sum.eval(), - atol=1e-4) + expected_unreduced_loss, unreduced_loss.eval(), atol=1e-4) + self.assertAllClose(expected_weights, actual_weights.eval()) def test_train_labels_none(self): """Tests that error is raised when labels is None.""" @@ -851,12 +889,15 @@ class MultiLabelHead(test.TestCase): labels = np.array([[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) - # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 - # = [[20/3, 10/3], [4, 8]] + # unreduced_loss = + # [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]] + # weights are reshaped to [2, 2, 1] to match logits. + expected_weights = [[[1.], [1.5]], [[2.], [2.5]]] # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_weighted_sum_loss = 39.6667 - expected_example_weight_sum = np.sum(weights) - actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( + expected_training_loss = 39.6667 + training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={'weights': weights}, mode=model_fn.ModeKeys.TRAIN, logits=logits, @@ -865,11 +906,10 @@ class MultiLabelHead(test.TestCase): with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_weighted_sum_loss, actual_weighted_sum_loss.eval(), - atol=atol) + expected_training_loss, training_loss.eval(), atol=atol) self.assertAllClose( - expected_example_weight_sum, actual_example_weight_sum.eval(), - atol=atol) + expected_unreduced_loss, unreduced_loss.eval(), atol=atol) + self.assertAllClose(expected_weights, actual_weights.eval()) def test_multi_dim_weighted_train(self): """Logits and labels of shape [2, 2, 3], weights [2, 2].""" diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index f2a6eae03ec021e5c28d48b3887870d8a057e077..0346ddc24bffd61068177f4622bd03be4acd53d9 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -186,40 +186,44 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access logits_dict = logits else: logits_dict = self._split_logits(logits) - weighted_sum_losses = [] - example_weight_sums = [] + training_losses = [] labels_by_head = {} - for head in self._heads: - (weighted_sum_loss, - example_weight_sum, processed_labels) = head.create_loss( + unreduced_losses_by_head = {} + example_weights_by_head = {} + for i, head in enumerate(self._heads): + (training_loss, unreduced_loss, + weights, processed_labels) = head.create_loss( features, mode, logits_dict[head.name], labels[head.name]) - weighted_sum_losses.append(weighted_sum_loss) - example_weight_sums.append(example_weight_sum) + training_losses.append(training_loss) labels_by_head[head.name] = processed_labels + if self._head_weights: + head_weight = self._head_weights[i] + unreduced_losses_by_head[head.name] = math_ops.multiply( + unreduced_loss, head_weight) + example_weights_by_head[head.name] = math_ops.multiply( + weights, head_weight) + else: + unreduced_losses_by_head[head.name] = unreduced_loss + example_weights_by_head[head.name] = weights - weighted_sum_losses = tuple(weighted_sum_losses) - with ops.name_scope('merge_losses', - values=weighted_sum_losses + (self._head_weights or - tuple())): + training_losses = tuple(training_losses) + with ops.name_scope( + 'merge_losses', + values=training_losses + (self._head_weights or tuple())): if self._head_weights: - head_weighted_losses = [] - head_weighted_example_weight_sums = [] - for loss, example_weight_sum, weight in zip(weighted_sum_losses, - example_weight_sums, - self._head_weights): - head_weighted_losses.append(math_ops.multiply(loss, weight)) - head_weighted_example_weight_sums.append(math_ops.multiply( - example_weight_sum, weight)) - merged_weighted_sum_loss = math_ops.add_n(head_weighted_losses) - merged_example_weight_sum = math_ops.add_n( - head_weighted_example_weight_sums) + head_weighted_training_losses = [] + for training_loss, head_weight in zip( + training_losses, self._head_weights): + head_weighted_training_losses.append( + math_ops.multiply(training_loss, head_weight)) + merged_training_loss = math_ops.add_n(head_weighted_training_losses) else: - merged_weighted_sum_loss = math_ops.add_n(weighted_sum_losses) - merged_example_weight_sum = math_ops.add_n(example_weight_sums) + merged_training_loss = math_ops.add_n(training_losses) return head_lib.LossSpec( - weighted_sum_loss=merged_weighted_sum_loss, - example_weight_sum=merged_example_weight_sum, + training_loss=merged_training_loss, + unreduced_loss=unreduced_losses_by_head, + weights=example_weights_by_head, processed_labels=labels_by_head) def create_estimator_spec( diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 68f2d5d1cd53456f7dd82222e171b3619052321a..65ea89ba1b9236d0bf4d2de430fab168ef50bf97 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -370,7 +370,7 @@ class MultiHeadTest(test.TestCase): 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), } - weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + training_loss, unreduced_losses, weights, _ = multi_head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), 'weights1': weights1, @@ -383,14 +383,23 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = 1 * 10 + 2 * 7.5 = 25 + # head-weighted unreduced_loss = 1 * [10, 7.5] + self.assertAllClose( + [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # weighted_sum_loss = 2 * 20 + 3 * 10 = 70 - # head-weighted merge = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol) - # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 - self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + # training_loss = 2 * 20 + 3 * 10 = 70 + # head-weighted unreduced_loss = 2 * [20, 10] + self.assertAllClose( + [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 + self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted example weights + self.assertAllClose( + [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) + self.assertAllClose( + [[4.], [6.]], weights['head2'].eval(), rtol=tol, atol=tol) def test_train_create_loss_logits_tensor(self): """Tests create_loss with logits Tensor.""" @@ -409,7 +418,7 @@ class MultiHeadTest(test.TestCase): 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), } - weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + training_loss, unreduced_losses, weights, _ = multi_head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), 'weights1': weights1, @@ -422,14 +431,23 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = 1 * 10 + 2 * 7.5 = 25 + # head-weighted unreduced_loss = 1 * [10, 7.5] + self.assertAllClose( + [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # weighted_sum_loss = 2 * 20 + 3 * 10 = 70 - # head-weighted merge = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol) - # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 - self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + # training_loss = 2 * 20 + 3 * 10 = 70 + # head-weighted unreduced_loss = 2 * [20, 10] + self.assertAllClose( + [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 + self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted example weights + self.assertAllClose( + [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) + self.assertAllClose( + [[4.], [6.]], weights['head2'].eval(), rtol=tol, atol=tol) def test_train_create_loss_logits_tensor_multi_dim(self): """Tests create_loss with multi-dimensional logits of shape [2, 2, 5].""" @@ -455,20 +473,17 @@ class MultiHeadTest(test.TestCase): # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 # = 74 - expected_weighted_sum_loss = 28. + 74. + expected_training_loss = 28. + 74. - weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + training_loss = multi_head.create_loss( features={}, mode=model_fn.ModeKeys.TRAIN, logits=logits, - labels=labels) + labels=labels)[0] tol = 1e-3 with self.test_session(): self.assertAllClose( - expected_weighted_sum_loss, weighted_sum_loss.eval(), - rtol=tol, atol=tol) - self.assertAllClose( - 2. * 2. * 5., example_weight_sum.eval(), rtol=tol, atol=tol) + expected_training_loss, training_loss.eval(), rtol=tol, atol=tol) def test_train_one_head(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index d9c83aa86577aa129458c56887ff4668c103d0db..caa9dd83233b6b850385335fde96431271d85c3a 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -23,6 +23,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import defaultdict +from contextlib import contextmanager import copy import six @@ -41,20 +43,24 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging -from tensorflow.python.training import training_util +from tensorflow.python.training import device_setter as device_setter_lib +from tensorflow.python.training import optimizer as optimizer_lib -def replicate_model_fn(model_fn, optimizer_fn, devices=None): - """Replicate `Estimator.model_fn` over GPUs within a single host. +def replicate_model_fn(model_fn, + loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + devices=None): + """Replicate `Estimator.model_fn` over GPUs. The given `model_fn` specifies a single forward pass of a model. To replicate such a model over GPUs, each GPU gets its own instance of the forward pass (a.k.a. a tower). The input features and labels get sharded into the chunks - that correspond to the number of GPUs. Each tower computes its own loss based + that correspond to the number of GPUs. Each tower computes a loss based on its input. For each such loss, gradients are computed. After that, the - available losses are summed to form aggregated loss. The available - gradients are summed too. Then, they update weights using the specified + available losses are aggregated to form aggregated loss. Available + gradients are summed. Then, they update weights using the specified optimizer. If `devices` are `None`, then all available GPUs are going to be used for @@ -63,36 +69,38 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): Two modes of local replication over available GPUs are supported: 1) If exactly 1 GPU is detected, then variables and operations are placed - onto GPU. + onto the GPU. 2) If more than 1 GPU is detected, then variables are going to be placed on the CPU. Replicas of operations are placed on each individual GPU. Here is an example of how one might use their `model_fn` to run over GPUs: ```python - def optimizer_fn(): - return tf.train.GradientDescentOptimizer(learning_rate=0.001) ... def model_fn(...): # See `model_fn` in `Estimator`. loss = ... + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001) + optimizer = tf.contrib.estimator.TowerOptimizer(optimizer) if mode == tf.estimator.ModeKeys.TRAIN: # See the section below on `EstimatorSpec.train_op`. - return EstimatorSpec(mode=mode, loss=loss, train_op=tf.noop()) + return EstimatorSpec(mode=mode, loss=loss, + train_op=optimizer.minimize(loss)) # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`. return EstimatorSpec(...) ... classifier = tf.estimator.Estimator( - model_fn=replicate_model_fn.replicate_model_fn(model_fn, optimizer_fn)) + model_fn=tf.contrib.estimator.replicate_model_fn(model_fn)) ``` + Please see `DNNClassifierIntegrationTest` for an example with a canned + Estimator. + On `EstimatorSpec.train_op`: `model_fn` returns `EstimatorSpec.train_op` for `tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer. - `replicate_model_fn` ignores the returned `EstimatorSpec.train_op`, so there - is no need to use an optimizer inside the user's `model_fn`. The - `EstimatorSpec.loss` subgraph is going to be executed, while - `EstimatorSpec.train_op` isn't going to be executed. One could pass - `train_op=tf.noop()` to `EstimatorSpec`. + Towers are expected to populate it in the same way. Gradients from all towers + are reduced and applied in the last tower. To achieve that in the case of + multiple towers, `TowerOptimizer` needs to be used. See `TowerOptimizer`. On sharding input features and labels: Input features and labels are split for consumption by each tower. They are @@ -101,7 +109,7 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): On reduction algorithms: Certain algorithms were chosen for aggregating results of computations on multiple towers: - - Losses from all towers are reduced using sum. + - Losses from all towers are reduced according to `loss_reduction`. - Gradients are reduced using sum for each trainable variable. - `eval_metrics_ops` are reduced per metric using `reduce_mean`. - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are @@ -109,65 +117,332 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): - For all other fields of `EstimatorSpec` the values of the first tower are taken. - On replication of variables: + On distribution of variables: Variables are not duplicated between towers. Instead, they are placed on a single device as defined above and shared across towers. - Other current limitations: - - `predictions` are not supported for `ModeKeys.EVAL`. That is required for - `tf.contrib.estimator.add_metrics`. + On overhead: + If only one device is specified, then aggregation of loss and gradients + doesn't happen. Replication consists of placing `model_fn` onto the + specified device. + + On current limitations: + - `predictions` are not supported for `ModeKeys.EVAL`. They are required + for `tf.contrib.estimator.add_metrics`. Args: model_fn: `model_fn` as defined in `Estimator`. See the section above about the train_op argument of `EstimatorSpec`. - optimizer_fn: a function that returns an optimizer instance. The function - may accept one `params` argument. This is the `params` argument as - defined by `Estimator`. See the `Estimator` documentation for details. + loss_reduction: controls whether losses are summed or averaged. devices: Optional list of devices to replicate the model across. This argument can be used to replice only on the subset of available GPUs. If `None`, then all available GPUs are going to be used for replication. If no GPUs are available, then the model is going to be placed on the CPU. + Raises: + ValueError: if there is no `loss_reduction` or if TowerOptimizer is + mis-used. + Returns: A replicated version of the supplied `model_fn`. Returned function that conforms to the requirements of `Estimator`'s `model_fn` and can be used instead of the supplied `model_fn`. """ + return _replicate_model_fn_with_mode( + model_fn, + loss_reduction, + devices, + # TODO(isaprykin): Query the system configuration to choose modes other + # than `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often + # appropriate. + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER) + + +class _VariableDistributionMode(object): + """Modes for variable distribution used for forcing a particular one. + + Forcing a mode is meant for performance experimentation purposes rather than + for general use cases. + """ + + SHARED_LOCAL_PARAMETER_SERVER = 1 + """Variables are placed on a single device and shared across all devices. + + Two ways to achieve this distribution over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + """ + + SHARED_ROUND_ROBIN = 2 + """Variables are placed on all devices in a round-robin fashion. + + Every subsequent variable is placed on the next device. There is only one + copy of each variable that is shared across all devices. + """ + + +def _replicate_model_fn_with_mode( + model_fn, + loss_reduction, + devices=None, + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER): + """A version of `replicate_model_fn` that allows to specify a `mode`.""" + if loss_reduction == losses.Reduction.NONE: + raise ValueError('Tower losses need to be reduced in some way, yet {} ' + 'reduction is specified.'.format(loss_reduction)) if not devices: devices = _get_local_devices('GPU') or _get_local_devices('CPU') is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] - local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU') + consolidation_device = devices[0] if is_a_single_gpu_case else '/CPU:0' - tf_logging.info('Replicating the `model_fn` across {}. Local parameter ' - 'server device is going to be {}.'.format( - devices, local_ps_device)) + ps_devices = [consolidation_device] + if mode == _VariableDistributionMode.SHARED_ROUND_ROBIN: + ps_devices = devices + + tf_logging.info('Replicating the `model_fn` across {}. Variables are going ' + 'to be placed on {}. Consolidation device is going to be {}.' + .format(devices, ps_devices, consolidation_device)) + + def single_device_model_fn(features, labels, mode, params=None, config=None): + """`model_fn` on a single device without reduction overhead.""" + return _get_loss_towers( + model_fn=model_fn, + mode=mode, + features=[features], + labels=[labels], + params=params, + loss_reduction=loss_reduction, + config=config, + devices=devices, + local_ps_devices=ps_devices)[0] # One device, so one spec is out. def replicated_model_fn(features, labels, mode, params=None, config=None): """Replicated version of `model_fn` to be used instead.""" feature_shards, label_shards = _split_batch( - features, labels, len(devices), device=local_ps_device) + features, labels, len(devices), device=consolidation_device) tower_specs = _get_loss_towers( model_fn=model_fn, mode=mode, features=feature_shards, labels=label_shards, params=params, + loss_reduction=loss_reduction, config=config, devices=devices, - local_ps_device=local_ps_device) + local_ps_devices=ps_devices) if mode == model_fn_lib.ModeKeys.TRAIN: - train_op = _minimize_towers(tower_specs, - _call_optimizer_fn(optimizer_fn, params)) + train_op = _minimize_towers(tower_specs) return _train_spec( - tower_specs, train_op, aggregation_device=local_ps_device) + tower_specs, train_op, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.EVAL: - return _eval_spec(tower_specs, aggregation_device=local_ps_device) + return _eval_spec(tower_specs, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.PREDICT: - return _predict_spec(tower_specs, aggregation_device=local_ps_device) + return _predict_spec(tower_specs, aggregation_device=consolidation_device) + + if len(devices) == 1: + return single_device_model_fn + else: + return replicated_model_fn + + +class TowerOptimizer(optimizer_lib.Optimizer): + """Gathers gradients from all towers and reduces them in the last one.""" + + COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states' - return replicated_model_fn + def __init__(self, optimizer_or_optimizer_fn): + """Wrap an existing optimizer for gathering gradients across towers. + + Each invocation of model_fn has to call the same optimizers in the same + order. + + Multiple optimizers that use the same or different losses are supported. + + If TowerOptimizer is used but `replicate_model_fn` isn't, then no + aggregation will happen. All calls will simply be forwarded to the + underlying optimizer. The behavior is similar if there is only one tower. + + If TowerOptimizer is used together with SyncReplicasOptimizer that wraps + the user's optimizer, then it's the SyncReplicasOptimizer that needs to be + wrapped with TowerOptimizer. + + Args: + optimizer_or_optimizer_fn: an instance of optimizer to wrap. That + instance is going to be used for optimizer-specific logic. This can + also be a no-argument function that returns such an optimizer instance. + """ + self._optimizer_or_optimizer_fn = optimizer_or_optimizer_fn + + @staticmethod + def has_been_used(): + return TowerOptimizer._graph_state().has_tower_optimizer_been_used + + def get_slot(self, *args, **kwargs): + return self._get_optimizer().get_slot(*args, **kwargs) + + def get_slot_names(self, *args, **kwargs): + return self._get_optimizer().get_slot_names(*args, **kwargs) + + def get_name(self, *args, **kwargs): + return self._get_optimizer().get_name(*args, **kwargs) + + def variables(self, *args, **kwargs): + return self._get_optimizer().variables(*args, **kwargs) + + def compute_gradients(self, loss, *args, **kwargs): + """Compute gradients, but first, if needed, scale the loss.""" + loss = _scale_loss(loss, + self._graph_state().loss_reduction, + self._graph_state().number_of_towers) + return self._get_optimizer().compute_gradients(loss, *args, **kwargs) + + def apply_gradients(self, grads_and_vars, global_step=None, **kwargs): + """Collect gradients updates to apply them with the last tower.""" + if self._graph_state().number_of_towers == 1: + # Avoid the overhead of reduction if there's only one tower. + # + # There assumed to be only one tower if aggregation-related methods were + # not called by `_get_loss_towers`, for example if the model_fn uses + # TowerEstimator, but `replicate_model_fn` isn't used. + return self._get_optimizer().apply_gradients(grads_and_vars, global_step, + **kwargs) + + self._graph_state().collect_gradients(grads_and_vars) + + if not self._graph_state().is_the_last_tower: + with ops_lib.control_dependencies(_extract_tensors(grads_and_vars)): + return self._construct_no_op_train_op() + else: + # Gradients need to be gathered and applied in the scope of the first + # tower, so that the tensors are accessible via names without prefixes. + var_scope, name_scope = self._graph_state().scopes_of_the_first_tower + with variable_scope.variable_scope(var_scope): + with ops_lib.name_scope(name_scope): + return self._apply_gathered_gradients(global_step, **kwargs) + + def _apply_gathered_gradients(self, global_step, **kwargs): + graph_state = self._graph_state() + optimizer = self._get_optimizer() + + grad_lists = {} + for grad, var in graph_state.get_latest_gradients_from_all_towers(): + if grad is not None: + grad_lists.setdefault(var, []).append(grad) + + aggregated_grads = [] + with ops_lib.name_scope('gradient_aggregating'): + for var, grads in six.iteritems(grad_lists): + grad = _compute_sum_on_device(grads, var.device) + aggregated_grads.append((grad, var)) + return optimizer.apply_gradients( + aggregated_grads, global_step=global_step, **kwargs) + + def _get_optimizer(self): + if callable(self._optimizer_or_optimizer_fn): + # If optimizer is given as a function then we need to wait till we are + # under the right graph context before constructing it. That's why the + # optimizer is constructed in _get_optimizer() rather than __init__(). + self._optimizer_or_optimizer_fn = self._optimizer_or_optimizer_fn() + self._graph_state().has_tower_optimizer_been_used = True + return self._optimizer_or_optimizer_fn + + def _construct_no_op_train_op(self): + return control_flow_ops.no_op(name='train_op_placeholder') + + @staticmethod + def _graph_state(): + graph_states = ops_lib.get_default_graph().get_collection_ref( + TowerOptimizer.COLLECTION_FOR_GRAPH_STATES) + if not graph_states: + graph_states.append(TowerOptimizer._PerGraphState()) + return graph_states[-1] + + @staticmethod + def _did_towers_have_same_optimizer_calls(): + graph_state = TowerOptimizer._graph_state() + return graph_state.did_towers_have_same_optimizer_calls() + + @staticmethod + def _clear_graph_state(): + # Clearing the Graph collection will prevent _PerGraphState from being + # serialized. + ops_lib.get_default_graph().clear_collection( + TowerOptimizer.COLLECTION_FOR_GRAPH_STATES) + + class _PerGraphState(object): + """Gradient reduction related state of a Tensorflow graph.""" + + def __init__(self): + self._collected_grads_and_vars = defaultdict(list) + self._current_tower_index = 0 + self._number_of_towers = 1 + self._loss_reduction = None + # Scopes of the first tower that don't have a prefix: + self._variable_scope = None + self._name_scope = None + # If needed, alert that TowerOptimizer needs to be used with model_fn. + self._has_tower_optimizer_been_used = False + + def collect_gradients(self, grads_and_vars): + self._collected_grads_and_vars[self._current_tower_index].append( + grads_and_vars) + + def get_latest_gradients_from_all_towers(self): + """Get gradients across towers for the last called optimizer.""" + grads_and_vars = [] + index_of_last_gradients = len( + self._collected_grads_and_vars[self._current_tower_index]) - 1 + for tower_id in range(self._current_tower_index + 1): + grads_and_vars.extend( + self._collected_grads_and_vars[tower_id][index_of_last_gradients]) + return grads_and_vars + + def set_reduction_across_towers(self, loss_reduction, number_of_towers): + self._loss_reduction = loss_reduction + self._number_of_towers = number_of_towers + + @contextmanager + def tower(self, tower_id, var_scope, name_scope): + if tower_id == 0: + self._variable_scope = var_scope + self._name_scope = name_scope + self._current_tower_index = tower_id + yield + + @property + def scopes_of_the_first_tower(self): + return self._variable_scope, self._name_scope + + @property + def is_the_last_tower(self): + return self._current_tower_index == (self._number_of_towers - 1) + + @property + def number_of_towers(self): + return self._number_of_towers + + @property + def loss_reduction(self): + return self._loss_reduction + + @property + def has_tower_optimizer_been_used(self): + return self._has_tower_optimizer_been_used + + @has_tower_optimizer_been_used.setter + def has_tower_optimizer_been_used(self, value): + self._has_tower_optimizer_been_used = value + + def did_towers_have_same_optimizer_calls(self): + total_number_of_grads = sum([ + len(grads) + for _, grads in six.iteritems(self._collected_grads_and_vars) + ]) + return total_number_of_grads % self._number_of_towers == 0 def _get_local_devices(device_type): @@ -222,7 +497,8 @@ def _get_loss_towers(model_fn, params, config, devices, - local_ps_device, + local_ps_devices, + loss_reduction, name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): """Replicate the loss computation across devices.""" tower_specs = [] @@ -234,36 +510,64 @@ def _get_loss_towers(model_fn, if 'config' in model_fn_args: optional_params['config'] = copy.deepcopy(config) + # pylint: disable=protected-access + round_robin_strategy = device_setter_lib._RoundRobinStrategy( + num_tasks=len(local_ps_devices)) + TowerOptimizer._graph_state().set_reduction_across_towers( + loss_reduction, len(devices)) + for i, device in enumerate(devices): is_the_first_tower = (i == 0) device_setter = _local_device_setter( - worker_device=device, ps_device=local_ps_device) + worker_device=device, + ps_devices=local_ps_devices, + ps_strategy=round_robin_strategy) - # We would like to preserve the names of the variables and ops that a user - # might be relying on. Names with prefix are going to resolve to variables - # and ops of the first tower. + # We would like to preserve the names of the variables and ops that the user + # might be relying on. Names without a prefix are going to resolve to + # variables and ops of the first tower. name_scope = name_scope_pattern if is_the_first_tower: name_scope = '' - with variable_scope.variable_scope('', reuse=not is_the_first_tower): - with ops_lib.name_scope(name_scope.format(i)): - with ops_lib.device(device_setter): - labels_shard = None - if labels: - labels_shard = labels[i] - - tower_specs.append( - model_fn( - mode=mode, - features=features[i], - labels=labels_shard, - **optional_params)) + with variable_scope.variable_scope( + '', reuse=not is_the_first_tower) as var_scope: + with ops_lib.name_scope(name_scope.format(i)) as name_scope: + with TowerOptimizer._graph_state().tower( + tower_id=i, var_scope=var_scope, name_scope=name_scope): + with ops_lib.device(device_setter): + labels_shard = None + if labels: + labels_shard = labels[i] + + tower_spec = model_fn( + mode=mode, + features=features[i], + labels=labels_shard, + **optional_params) + + if (tower_spec.train_op is not None and len(devices) > 1 and + not TowerOptimizer.has_been_used()): + raise ValueError('Please wrap optimizers with TowerOptimizer' + ' in order to use replicate_model_fn with' + ' multiple `devices`.') + + # Scaling the loss here doesn't actually affect gradients. Another + # instance of scaling happens inside the TowerOptimizer. + tower_spec = _scale_tower_loss( + tower_spec, loss_reduction, number_of_towers=len(devices)) + tower_specs.append(tower_spec) + + if not TowerOptimizer._did_towers_have_same_optimizer_calls(): + raise ValueError('Each invocation of model_fn was supposed to make the same' + ' optimizer calls.') + TowerOptimizer._clear_graph_state() + # pylint: enable=protected-access return tower_specs -def _local_device_setter(ps_device, worker_device): +def _local_device_setter(worker_device, ps_devices, ps_strategy): """A device setter that puts distributes Var/Ops to PS/workers.""" ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] @@ -273,7 +577,7 @@ def _local_device_setter(ps_device, worker_device): node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def if node_def.op in ps_ops: ps_device_spec = framework_device.DeviceSpec.from_string( - '{}'.format(ps_device)) + '{}'.format(ps_devices[ps_strategy(op)])) ps_device_spec.merge_from(current_device) return ps_device_spec.to_string() @@ -286,33 +590,33 @@ def _local_device_setter(ps_device, worker_device): return local_device_chooser -def _minimize_towers(tower_specs, optimizer): - """Aggregate and apply gradients for computed losses.""" - grad_lists = {} - for tower_spec in tower_specs: - with ops_lib.device(tower_spec.loss.device): - for grad, var in optimizer.compute_gradients(tower_spec.loss): - if grad is not None: - grad_lists.setdefault(var, []).append(grad) +def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers): + """Produce an EstimatorSpec with approproriately scaled loss.""" + if tower_spec.loss is None: + return tower_spec + + estimator_spec = _asdict(tower_spec) + estimator_spec['loss'] = _scale_loss(tower_spec.loss, loss_reduction, + number_of_towers) + return model_fn_lib.EstimatorSpec(**estimator_spec) - aggregated_grads = [] - with ops_lib.name_scope('gradient_aggregating'): - for var, grads in six.iteritems(grad_lists): - grad = _compute_sum_on_device(grads, var.device) - aggregated_grads.append((grad, var)) - train_op = optimizer.apply_gradients( - aggregated_grads, global_step=training_util.get_global_step()) +def _scale_loss(loss, loss_reduction, number_of_towers): + """If needed, scale down the loss for averaging loss by summing.""" + if loss is None: + return None + if number_of_towers == 1: + return loss - return train_op + if loss_reduction != losses.Reduction.SUM: + return math_ops.div(loss, 1.0 * number_of_towers, name='averaged_loss') + else: + return loss -def _call_optimizer_fn(optimizer_fn, params): - arguments = {} - optimizer_fn_arguments = util.fn_args(optimizer_fn) - if 'params' in optimizer_fn_arguments: - arguments['params'] = params - return optimizer_fn(**arguments) +def _minimize_towers(tower_specs): + """`train_op` of the last tower applies aggregated gradients.""" + return tower_specs[-1].train_op def _compute_sum_on_device(values, device, name=None): @@ -335,7 +639,12 @@ def _train_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`.""" - estimator_spec = tower_specs[0]._asdict() + # Spec of the last tower is used as the template for the final spec, because + # some `EstimatorSpec.training_hooks` rely on calls made in model_fn. For + # example, `SyncReplicasOptimizerHook` validates the + # `SyncReplicasOptimizer.apply_gradients` call. `TowerEstimator` makes that + # call only in the last tower. + estimator_spec = _asdict(tower_specs[-1]) estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN estimator_spec['train_op'] = train_op estimator_spec['loss'] = _compute_sum_on_device( @@ -346,7 +655,7 @@ def _train_spec(tower_specs, def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): """Populate replicated EstimatorSpec for `GraphKeys.EVAL`.""" - estimator_spec = tower_specs[0]._asdict() + estimator_spec = _asdict(tower_specs[0]) estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL estimator_spec['loss'] = _compute_sum_on_device( [spec.loss for spec in tower_specs], aggregation_device, @@ -370,7 +679,7 @@ def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): def _reduce_metric_variables(number_of_towers): """Aggregate local variables used in metrics into the first tower.""" if number_of_towers == 1: - return control_flow_ops.no_op() + return control_flow_ops.no_op(name='no_eval_metric_reduction') metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES) variables_per_tower = len(metric_variables) // number_of_towers @@ -414,7 +723,7 @@ def _reduce_metric_variables(number_of_towers): def _predict_spec(tower_specs, aggregation_device): """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`.""" - estimator_spec = tower_specs[0]._asdict() + estimator_spec = _asdict(tower_specs[0]) estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT with ops_lib.device(aggregation_device): @@ -465,6 +774,17 @@ def _concat_tensor_dicts(*tensor_dicts): } +def _extract_tensors(tensors_and_vars): + tensors = [] + for tensor_and_var in tensors_and_vars: + tensor, _ = tensor_and_var + if isinstance(tensor, ops_lib.IndexedSlices): + tensors.append(tensor.values) + else: + tensors.append(tensor) + return tensors + + def _dict_concat(*dicts): list_dict = {} for d in dicts: @@ -474,3 +794,19 @@ def _dict_concat(*dicts): for k, v in six.iteritems(d): list_dict.setdefault(k, []).append(v) return list_dict + + +def _asdict(namedtuple): + """Returns a namedtuple as a dictionary. + + This is required because `_asdict()` in Python 3.x.x is broken in classes + that inherit from `collections.namedtuple`. See + https://bugs.python.org/issue24931 for more details. + + Args: + namedtuple: An object that inherits from `collections.namedtuple`. + + Returns: + A dictionary version of the tuple. + """ + return {k: getattr(namedtuple, k) for k in namedtuple._fields} diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index ffe69f89b4c4d48d329a1aef3aa3cad2b17b3fdf..03d31226af613960a19ce116b19b30153b1fdcee 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -40,6 +40,7 @@ from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import losses from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variable_scope @@ -49,15 +50,32 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import adam +from tensorflow.python.training import device_setter from tensorflow.python.training import gradient_descent +from tensorflow.python.training import training +# TODO(isaprykin): Parametrize all the tests on +# replicate_model_fn._VariableDistributionMode when it's supported. class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): def setUp(self): self._model_dir = tempfile.mkdtemp() - def test_complete_flow(self): + def test_complete_flow_with_public_version(self): + return self._complete_flow_with_mode(mode=None) + + def test_complete_flow_with_mode_local_ps_server(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode. + SHARED_LOCAL_PARAMETER_SERVER) + + def test_complete_flow_with_mode_round_robin(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) + + def _complete_flow_with_mode(self, mode): n_classes = 3 input_dimension = 2 batch_size = 12 @@ -96,20 +114,30 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): 0., len(x_data), len(x_data), dtype=np.int64)), 1) ] + def optimizer_fn(): + return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) + estimator = dnn.DNNClassifier( hidden_units=(2, 2), + # Adagrad is configured with `get_optimizer_instance`, so the function + # form of `TowerOptimizer.__init__` is used. + optimizer=replicate_model_fn.TowerOptimizer(optimizer_fn), feature_columns=feature_columns, n_classes=n_classes, model_dir=self._model_dir) - def optimizer_fn(): - return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) + if not mode: # Use the public `replicate_model_fn`. + model_fn = replicate_model_fn.replicate_model_fn( + estimator.model_fn, devices=['/gpu:0', '/gpu:1', '/gpu:2']) + else: + model_fn = replicate_model_fn._replicate_model_fn_with_mode( + estimator.model_fn, + devices=['/gpu:0', '/gpu:1', '/gpu:2'], + loss_reduction=losses.Reduction.SUM, + mode=mode) estimator = estimator_lib.Estimator( - model_fn=replicate_model_fn.replicate_model_fn( - estimator.model_fn, - optimizer_fn, - devices=['/gpu:0', '/gpu:1', '/gpu:2']), + model_fn=model_fn, model_dir=estimator.model_dir, config=estimator.config, params=estimator.params) @@ -134,6 +162,10 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): serving_input_receiver_fn) self.assertTrue(gfile.Exists(export_dir)) + # Nothing should be left in the graph so that it doesn't get serialized. + self.assertFalse(ops_lib.get_default_graph().get_collection_ref( + replicate_model_fn.TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)) + def _as_label(self, data_in_float): return np.rint(data_in_float).astype(np.int64) @@ -153,28 +185,24 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): predictions = math_ops.multiply(features, c) - loss = None - if mode is not model_fn_lib.ModeKeys.PREDICT: - loss = losses.absolute_difference( - labels=labels, - predictions=predictions, - reduction=losses.Reduction.SUM) - loss = math_ops.reduce_sum(loss) + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) metrics = { 'accuracy': metrics_lib.accuracy(labels, predictions), 'auc': metrics_lib.auc(labels, predictions) } + optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(params['learning_rate'])) + return model_fn_lib.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=metrics, predictions={'probabilities': predictions}, - train_op=control_flow_ops.no_op()) # This train_op isn't actually used. - - def optimizer_fn(self, params): - return gradient_descent.GradientDescentOptimizer(params['learning_rate']) + train_op=optimizer.minimize(loss)) @property def params(self): @@ -188,7 +216,9 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) @@ -197,31 +227,71 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) self.assertEqual(total_loss, session.run(estimator_spec.loss)) - # loss' of c is 3. + # derivative of loss = (1*c - 1) + (2*c - 2) is 3. # new value of c = 10 - learning rate * 3 = 7.0. session.run(estimator_spec.train_op) with variable_scope.variable_scope('', reuse=True): c = variable_scope.get_variable('c', dtype=dtypes.float64) self.assertEqual(7.0, session.run(c)) - def test_train_spec_with_optimizer_without_params(self): - - def optimizer_fn_without_params(): - return gradient_descent.GradientDescentOptimizer(learning_rate=1.0) - + def test_train_with_mean_reduction(self): features = np.array([[1.0], [2.0]]) labels = np.array([[1.0], [2.0]]) - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, - optimizer_fn_without_params, - devices=['/gpu:0', '/gpu:1']) - # This call is going to fail if `replicated_model_fn` is still passing - # `params` inside `optimizer_fn`, even though the latter doesn't take any: + self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) - del estimator_spec + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)) / 2.0 + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # derivative of loss = (1*c - 1)/2 + (2*c - 2)/2 is 1.5. + # It's the same computation as without mean reduction, but the + # loss from every tower is scaled by 1/. + # new value of c = 10 - learning rate * 1.5 = 8.5 + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(8.5, session.run(c)) + + def test_train_two_steps_collected_gradients_are_reset_between_steps(self): + with ops_lib.Graph().as_default(): + features = array_ops.placeholder(dtypes.float64) + labels = array_ops.placeholder(dtypes.float64) + + feature_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]]) + label_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]]) + + # loss = feature * c - label + expected_losses = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0), + (1.5 * 7.0 - 1.5) + (2.5 * 7.0 - 2.5)) + # Derivative of the loss is 1.0 + 2.0 for the first step and 1.5 + 2.5 + # for the second. + expected_c = 10.0 - 3.0, 7.0 - 4.0 + + with self.test_session() as session, variable_scope.variable_scope( + '', reuse=variable_scope.AUTO_REUSE): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + for feature_input, label_input, loss, weight in zip( + feature_inputs, label_inputs, expected_losses, expected_c): + feeds = {features: feature_input, labels: label_input} + + self.assertEqual(loss, session.run(estimator_spec.loss, feeds)) + + session.run(estimator_spec.train_op, feeds) + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(weight, session.run(c, feeds)) def test_eval(self): features = np.array([[0.01], [0.002]]) @@ -229,7 +299,9 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.EVAL, self.params) session.run(variables.local_variables_initializer()) @@ -252,13 +324,42 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): self.assertEqual(0, auc) self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + def test_eval_with_mean_reduction(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # loss[i] = features[i] * 10 - labels[i]. + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) / 2.0 + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + def test_predict(self): features = np.array([[0.01], [0.002]]) labels = np.array([[0.01], [0.02]]) with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + self.model_fn, devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) @@ -273,7 +374,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn) + self.model_fn, devices=['/gpu:0']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) @@ -295,7 +396,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + self.model_fn, devices=['/gpu:0']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.EVAL, self.params) session.run(variables.local_variables_initializer()) @@ -323,7 +424,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + self.model_fn, devices=['/gpu:0']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) @@ -332,6 +433,412 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): 'probabilities': np.array([[0.1], [0.02]]) }, session.run(estimator_spec.predictions)) + def test_unsupported_loss_reduction(self): + with self.assertRaisesRegexp(ValueError, + '.+none.+reduction.+is.+specified.+'): + _ = replicate_model_fn.replicate_model_fn(self.model_fn, + losses.Reduction.NONE) + + +class ReplicateAcrossASingleDeviceWithoutTowerOptimizer( + test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = gradient_descent.GradientDescentOptimizer( + params['learning_rate']) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=optimizer.minimize(loss)) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + +class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + features = features['features'] + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(params['learning_rate'])) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=optimizer.minimize(loss)) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + train_input_fn = numpy_io.numpy_input_fn( + x={'features': features}, y=labels, batch_size=2, shuffle=False) + + with self.test_session(): + estimator = estimator_lib.Estimator( + model_fn=self.model_fn, + model_dir=tempfile.mkdtemp(), + params=self.params) + estimator.train(train_input_fn, steps=1) + + self.assertEqual(7.0, estimator.get_variable_value('c')) + + +class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + features = features['features'] + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = gradient_descent.GradientDescentOptimizer( + params['learning_rate']) + optimizer = training.SyncReplicasOptimizer( + optimizer, replicas_to_aggregate=1) + sync_hook = optimizer.make_session_run_hook(True) + optimizer = replicate_model_fn.TowerOptimizer(optimizer) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + training_hooks=[sync_hook], + predictions={'probabilities': predictions}, + train_op=optimizer.minimize( + loss, global_step=training.get_global_step())) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train_multiple_towers(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + train_input_fn = numpy_io.numpy_input_fn( + x={'features': features}, y=labels, batch_size=2, shuffle=False) + + model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) + + estimator = estimator_lib.Estimator( + model_fn=model_fn, model_dir=tempfile.mkdtemp(), params=self.params) + estimator.train(train_input_fn, steps=1) + + self.assertEqual(7.0, estimator.get_variable_value('c')) + + +class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + side_effects = variable_scope.get_variable( + 'side_effects', + initializer=constant_op.constant(0, dtype=dtypes.float64), + dtype=dtypes.float64, + trainable=False) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + first_optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0)) + second_optimizer = replicate_model_fn.TowerOptimizer( + adam.AdamOptimizer(1.0)) + + with ops_lib.control_dependencies([side_effects.assign_add(1.0)]): + first_grads_and_vars = first_optimizer.compute_gradients(loss) + + train_op = control_flow_ops.group( + [first_optimizer.apply_gradients(first_grads_and_vars), + second_optimizer.minimize(loss)]) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=train_op) + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.TRAIN, {}) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + # Adam subtracts another ~1. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertNear(6.0, session.run(c), 0.000001) + + side_effects = variable_scope.get_variable( + 'side_effects', dtype=dtypes.float64) + self.assertNear(2.0, session.run(side_effects), 0.000001) + + +class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase): + + def setUp(self): + self._should_skip_optimizer = False + self._towers_left_before_skipping_optimizer = -1 + + def incorrectly_skip_optimizer_for_tower(self, tower_number): + self._should_skip_optimizer = True + self._towers_left_before_skipping_optimizer = tower_number + + def should_skip_optimizer(self): + if not self._should_skip_optimizer: + return False + if self._towers_left_before_skipping_optimizer == 0: + return True + else: + self._towers_left_before_skipping_optimizer -= 1 + return False + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + d = variable_scope.get_variable( + 'd', + initializer=constant_op.constant(2, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + another_predictions = math_ops.multiply(features, d) + another_loss = losses.absolute_difference( + labels=labels, + predictions=another_predictions, + reduction=losses.Reduction.SUM) + another_loss = math_ops.reduce_sum(another_loss) + + total_loss = math_ops.add(loss, another_loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + train_ops = [] + + optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0)) + train_ops.append(optimizer.minimize(loss, var_list=[c])) + if not self.should_skip_optimizer(): + another_optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0)) + train_ops.append(another_optimizer.minimize(another_loss, var_list=[d])) + + train_op = control_flow_ops.group(train_ops) + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=total_loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=train_op) + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.TRAIN, {}) + session.run(variables.global_variables_initializer()) + + # For each tower, loss = (feature * c - label) + (feature * d - label). + total_loss = (1.0 * 10 - 1.0 + 1.0 * 2.0 - 1.0) + ( + 2.0 * 10 - 2.0 + 2.0 * 2.0 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + session.run(estimator_spec.train_op) + + # loss' of c or loss' of d is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + # new value of d = 2 - learning rate * 3 = -1.0. + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertNear(7.0, session.run(c), 0.000001) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertNear(-1.0, session.run(d), 0.000001) + + def test_different_optimizer_calls_within_towers(self): + self.incorrectly_skip_optimizer_for_tower(1) + + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session(), ops_lib.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, '.+was.+supposed.+to.+make.+same.+optimizer.+calls.+'): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN, + {}) + + +class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = gradient_descent.GradientDescentOptimizer(1.0) + train_op = optimizer.minimize(loss) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=train_op) + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session(): + with self.assertRaisesRegexp(ValueError, + 'Please.+wrap.+with.+TowerOptimizer'): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN, + {}) + class GetLossTowersTest(test_util.TensorFlowTestCase): @@ -358,8 +865,9 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): labels=[[0.6], [0.6]], params=None, config=None, + loss_reduction=losses.Reduction.SUM, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], name_scope_pattern='test_tower_{}') session.run(variables.global_variables_initializer()) @@ -382,6 +890,89 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): c = variable_scope.get_variable('c', dtype=dtypes.float64) self.assertEqual(0.25, session.run(c)) + def test_gradients_are_computed_with_mean_reduction(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=model_fn_lib.ModeKeys.EVAL, + features=[[0.6], [1.6]], + labels=[[0.6], [0.6]], + params=None, + loss_reduction=losses.Reduction.MEAN, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_devices=['/gpu:0'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 2) + + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('averaged_loss:0', tower_specs[0].loss.name) + self.assertEqual(0.5, session.run(tower_specs[0].loss)) + + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('test_tower_1/averaged_loss:0', tower_specs[1].loss.name) + # The input batch for the second tower had a loss that is 1.0 + # bigger: 0.6 vs 1.6. + self.assertEqual(1.0, session.run(tower_specs[1].loss)) + + self.assertEqual(1, len(variables.global_variables())) + self.assertEqual(1, len(variables.trainable_variables())) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(0.25, session.run(c)) + + def test_variables_are_round_robined_correctly(self): + """Test that creates multiple variables and tests round-robin placement.""" + + def model_fn(mode, features, labels, params): + del params + for variable_name in ['a', 'b', 'c', 'd']: + c = variable_scope.get_variable( + variable_name, + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=math_ops.reduce_sum(loss)) + + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + model_fn, + mode=None, + features=[[0.6], [1.6], [2.6]], + labels=[[0.6], [0.6], [2.6]], + params=None, + loss_reduction=losses.Reduction.SUM, + config=None, + devices=['/gpu:0', '/gpu:1', '/gpu:3'], + local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 3) + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('/device:GPU:3', tower_specs[2].loss.device) + + with variable_scope.variable_scope('', reuse=True): + a = variable_scope.get_variable('a', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', a.device) + b = variable_scope.get_variable('b', dtype=dtypes.float64) + self.assertEqual('/device:GPU:1', b.device) + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:3', c.device) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', d.device) + class SplitBatchTest(test_util.TensorFlowTestCase): @@ -600,11 +1191,12 @@ class PredictSpecTest(test_util.TensorFlowTestCase): self.model_fn, mode=None, features=[[0.1], [0.2]], + loss_reduction=losses.Reduction.SUM, labels=[[], []], params=None, config=None, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], ) session.run(variables.global_variables_initializer()) @@ -718,16 +1310,14 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase): variables.variables_initializer( ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) - with self.assertRaisesRegexp(ValueError, ''): + with self.assertRaisesRegexp( + ValueError, '.+Expected.+local.+variables.+but.+got.+instead.+'): session.run( replicate_model_fn._reduce_metric_variables(number_of_towers=3)) class MergeExportOutputsTest(test_util.TensorFlowTestCase): - def optimizer_fn(self): - return gradient_descent.GradientDescentOptimizer(1.0) - def model_fn(self, mode, features, labels, params): c = variable_scope.get_variable( 'c', @@ -769,7 +1359,6 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase): loss=math_ops.reduce_sum(loss), eval_metric_ops=metrics, predictions=predictions, - train_op=loss, # This train_op isn't actually used. export_outputs=export_outputs) def replicate_estimator_spec(self, session): @@ -777,13 +1366,13 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase): labels = np.array([0.01, 0.02]) replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + self.model_fn, devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.PREDICT, {}) session.run(variables.global_variables_initializer()) return estimator_spec - def test_merde_predict_output(self): + def test_merge_predict_output(self): with self.test_session() as session: estimator_spec = self.replicate_estimator_spec(session) self.assertAllClose( @@ -850,25 +1439,66 @@ class GetLocalDevicesTest(test_util.TensorFlowTestCase): class LocalDeviceSetterTest(test_util.TensorFlowTestCase): def test_vars_are_on_ps_but_ops_are_on_workers(self): + ps_devices = ['/device:GPU:3'] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + + local_device_setter = replicate_model_fn._local_device_setter( + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', b.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', c.device) + + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + def test_round_robin_placement(self): + ps_devices = [ + '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4' + ] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + local_device_setter = replicate_model_fn._local_device_setter( - ps_device='/device:GPU:3', worker_device='/device:GPU:2') + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') with ops_lib.device(local_device_setter): - c = variables.Variable(0.01) + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:0', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:1', b.device) + + c = variables.Variable(0.03) self.assertEqual('/device:GPU:3', c.device) - cc = variables.Variable(0.02) - self.assertEqual('/device:GPU:3', cc.device) + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) - ccc = variables.Variable(0.03) - self.assertEqual('/device:GPU:3', ccc.device) + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:4', c.device) + + d = variables.Variable(0.03) + self.assertEqual('/device:GPU:0', d.device) c_op = array_ops.concat(c, axis=0) self.assertEqual('/device:GPU:2', c_op.device) - cc_op = array_ops.concat(cc, axis=0) - self.assertEqual('/device:GPU:2', cc_op.device) - class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): @@ -939,7 +1569,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): dense_shape=constant_op.constant([2])) b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) - with self.assertRaisesRegexp(ValueError, ''): + with self.assertRaisesRegexp(ValueError, '.+name.+not.+expected.+'): _ = replicate_model_fn._compute_sum_on_device( [a, b], device='/device:GPU:0', name='cant_name_indexslices') diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD index 363baa121ab3854a802ca3606e35597d31b35a57..bbe842bd5ccc7357805adda1df42ba8799fcd8f2 100644 --- a/tensorflow/contrib/factorization/examples/BUILD +++ b/tensorflow/contrib/factorization/examples/BUILD @@ -21,3 +21,14 @@ tf_py_test( ], tags = ["notsan"], ) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 96cc80ce241347ebca5b68140f1b1c8b9898ae72..6d3acb2750743318aad83991bc1e89d64c329423 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -261,8 +261,8 @@ class KMeans(object): inp, clusters, 1) if self._distance_metric == COSINE_DISTANCE: distances *= 0.5 - output.append((score, array_ops.squeeze(distances), - array_ops.squeeze(indices))) + output.append((score, array_ops.squeeze(distances, [-1]), + array_ops.squeeze(indices, [-1]))) return zip(*output) def _clusters_l2_normalized(self): diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index 0d67e09f8151b48c97094b6b48f26e63443707ef..f72280c4ecf19e33278ffe74061f44bbb7b21709 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -24,7 +24,7 @@ import numpy as np from tensorflow.contrib import framework from tensorflow.contrib.factorization.python.ops import gmm_ops from tensorflow.contrib.framework.python.framework import checkpoint_utils -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.python.framework import constant_op @@ -167,7 +167,7 @@ class GMM(estimator.Estimator): self._num_clusters, self._random_seed, self._covariance_type, self._params) - incr_step = state_ops.assign_add(variables.get_global_step(), 1) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) loss = math_ops.reduce_sum(losses) training_op = with_dependencies([training_op, incr_step], loss) training_hooks = [_InitializeClustersHook( diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index 4709d7942583f1406a3fa0ff3a078d0283872ea6..f9598bfc08c05ea3bba88b3135da0cf2e6bb0c95 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -194,15 +194,7 @@ class KMeansTest(KMeansTestBase): score = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) self.assertNear(self.true_score, score, self.true_score * 0.01) - def test_infer(self): - kmeans = self._kmeans() - # Make a call to fit to initialize the cluster centers. - max_steps = 1 - kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) - clusters = kmeans.cluster_centers() - - # Make a small test set - num_points = 10 + def _infer_helper(self, kmeans, clusters, num_points): points, true_assignments, true_offsets = make_random_points( clusters, num_points) input_fn = self.input_fn(batch_size=num_points, points=points, num_epochs=1) @@ -223,6 +215,17 @@ class KMeansTest(KMeansTestBase): np.sum(np.square(clusters), axis=1, keepdims=True))) self.assertAllClose(transform, true_transform, rtol=0.05, atol=10) + def test_infer(self): + kmeans = self._kmeans() + # Make a call to fit to initialize the cluster centers. + max_steps = 1 + kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + clusters = kmeans.cluster_centers() + + # Run inference on small datasets. + self._infer_helper(kmeans, clusters, 10) + self._infer_helper(kmeans, clusters, 1) + class KMeansTestMultiStageInit(KMeansTestBase): diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index dc5a04a0b15870babbc98cf104e109caf829901c..eccce99071dc1477cf4f3bb152f3304b3b0fc35a 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -155,7 +155,10 @@ tf_py_test( data = [ ":test_data", ], - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) py_library( diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index 871dff7bbe4912f0daf2bc184d6b0f12510abee7..daba965a98893b992abdc598ec713f13020d6e91 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -26,6 +26,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio +from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video diff --git a/tensorflow/contrib/ffmpeg/decode_video_op_test.py b/tensorflow/contrib/ffmpeg/decode_video_op_test.py index 4d1fac4ef8afbf44cd45bae065f8a95b0527079a..b43b6b8919223bd7731209d5423b142601396ea5 100644 --- a/tensorflow/contrib/ffmpeg/decode_video_op_test.py +++ b/tensorflow/contrib/ffmpeg/decode_video_op_test.py @@ -20,11 +20,9 @@ from __future__ import print_function import os.path -import six +import six # pylint: disable=unused-import from tensorflow.contrib import ffmpeg -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops from tensorflow.python.ops import image_ops from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -32,7 +30,8 @@ from tensorflow.python.platform import test class DecodeVideoOpTest(test.TestCase): - def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, index): + def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, + index): """Loads an video file and validates the output tensor. Args: @@ -40,6 +39,8 @@ class DecodeVideoOpTest(test.TestCase): width: The width of the video. height: The height of the video. frames: The frames of the video. + bmp_filename: The filename for the bmp file. + index: Index location inside the video. """ with self.test_session(): path = os.path.join(resource_loader.get_data_files_path(), 'testdata', @@ -48,7 +49,7 @@ class DecodeVideoOpTest(test.TestCase): contents = f.read() bmp_path = os.path.join(resource_loader.get_data_files_path(), 'testdata', - bmp_filename) + bmp_filename) with open(bmp_path, 'rb') as f: bmp_contents = f.read() @@ -58,7 +59,7 @@ class DecodeVideoOpTest(test.TestCase): video_op = ffmpeg.decode_video(contents) video = video_op.eval() self.assertEqual(video.shape, (frames, height, width, 3)) - self.assertAllEqual(video[index,:,:,:], image) + self.assertAllEqual(video[index, :, :, :], image) def testMp4(self): self._loadFileAndTest('small.mp4', 560, 320, 166, 'small_100.bmp', 99) diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD index 949ae9ad9e4b045ee1b5cc82d49c0e7468c2005d..6b455567d766dbe6d380a498bd7f521db27e077b 100644 --- a/tensorflow/contrib/ffmpeg/default/BUILD +++ b/tensorflow/contrib/ffmpeg/default/BUILD @@ -19,6 +19,7 @@ cc_library( ], deps = [ "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], ) diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 201774e1d011f35df9c3803f2ed8818cc9b1c1c2..1e8af1458cea13b2ddb89b7d93a4ffb8b974ecd2 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -49,7 +49,8 @@ std::vector FfmpegAudioCommandLine(const string& input_filename, "-nostdin", // No interactive commands accepted. "-f", input_format_id, // eg: "mp3" "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, - "-loglevel", "info", // Enable verbose logging to support debugging. + "-loglevel", "error", // Print errors only. + "-hide_banner", // Skip printing build options, version, etc. "-map_metadata", "-1", // Copy global metadata from input to output. "-vn", // No video recording. "-ac:a:0", StrCat(channel_count), "-ar:a:0", @@ -72,7 +73,8 @@ std::vector FfmpegVideoCommandLine(const string& input_filename, "-probesize", StrCat(kDefaultProbeSize), "-loglevel", - "info", // Enable verbose logging to support debugging. + "error", // Print errors only. + "-hide_banner", // Skip printing build options, version, etc. "-vcodec", "rawvideo", "-pix_fmt", @@ -220,7 +222,8 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count, Status ReadInfoFile(const string& filename, uint32* width, uint32* height, uint32* frames) { string data; - ReadFileToString(Env::Default(), filename, &data); + TF_QCHECK_OK(ReadFileToString(Env::Default(), filename, &data)) + << "Could not read FFmpeg file: " << filename; bool in_output = false; bool in_mapping = false; uint32 frames_value = 0; @@ -377,7 +380,7 @@ Status ReadVideoFile(const string& filename, std::vector* output_data, open(stderr_filename.c_str(), O_RDWR | O_CREAT | O_APPEND, 0600); if (fd < 0) { const int error = errno; - LOG(ERROR) << "FFmpeg stderr file coule not be created: " + LOG(ERROR) << "FFmpeg stderr file could not be created: " << strerror(error); ::_exit(error); } diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc index 39e7e90cccf1012eb42261bde55d0dc3b7f278ef..36fc71794b06e0f3cb86c40b325ce50e8999c667 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -23,6 +23,7 @@ #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 78ead471d2cf9f0654a06dc022d7cc592d14c710..08b5a6ea48c2d4959af68a2ee9d27d21c6245457 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py +from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 5b659ddaa1386736eb8cc05a203ed1827ccd160e..9e5f54f0973eae899ca65e4098358107053cb7d4 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -11,11 +11,12 @@ package(default_visibility = [ ]) load("//tensorflow:tensorflow.bzl", "py_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") tf_custom_op_py_library( name = "framework_py", @@ -31,8 +32,10 @@ tf_custom_op_py_library( "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", + "python/ops/critical_section_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", + "python/ops/script_ops.py", "python/ops/sort_ops.py", "python/ops/variables.py", ], @@ -60,6 +63,7 @@ tf_custom_op_py_library( "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:script_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", @@ -70,6 +74,7 @@ tf_custom_op_py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", "//third_party/py/numpy", "@six_archive//:six", ], @@ -173,6 +178,21 @@ py_test( ], ) +cuda_py_test( + name = "critical_section_test", + size = "medium", + srcs = ["python/ops/critical_section_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + ":framework_py", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + ], +) + py_test( name = "accumulate_n_v2_eager_test", size = "small", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 4edc77f86ba786ca547b8d3842e2cf02833fbbac..673c51784229bd88011f8b33fb851a2885566220 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -81,7 +81,10 @@ See the @{$python/contrib.framework} guide. @@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer +@@py_func @@sort + +@@CriticalSection """ from __future__ import absolute_import diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 6d5cde5c9e118d372a6532bfc593bd08b9e18a7b..a18ff2320d99726bb355ff6179fc97a070c2fec7 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -150,5 +150,5 @@ def get_placeholders(graph): # The return value (a Tensor) of placeholder() is the # first output of this operation in fact. operations = graph.get_operations() - result = [i.outputs[0] for i in operations if i.type == 'Placeholder'] + result = [i.outputs[0] for i in operations if i.type == "Placeholder"] return result diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py index 0722fafc132c0db2ad621f6f9345185f34c643f5..b8a6d109e19211d271c2b15bac66ddacd38fe395 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util_test.py +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -90,8 +90,9 @@ class GetPlaceholdersTest(test.TestCase): with ops.Graph().as_default() as g: placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)] results = graph_util.get_placeholders(g) - self.assertEqual(sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access - sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access + self.assertEqual( + sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access + sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access if __name__ == '__main__': diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index 685bb94779762ce46ee342e7e0a182c54be64743..c4976497f5fa95d82e492153b117681f693eaa13 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -22,8 +22,10 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.ops.arg_scope import * from tensorflow.contrib.framework.python.ops.checkpoint_ops import * +from tensorflow.contrib.framework.python.ops.critical_section_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.prettyprint_ops import * +from tensorflow.contrib.framework.python.ops.script_ops import * from tensorflow.contrib.framework.python.ops.sort_ops import * from tensorflow.contrib.framework.python.ops.variables import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..182fec924febb74a23b82b1664d137f033f3b1b4 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -0,0 +1,324 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Critical Section object and execution logic.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +# TODO(ebrevdo): Re-enable once CriticalSection is in core. +# from tensorflow.core.protobuf import critical_section_pb2 + +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_resource_variable_ops +from tensorflow.python.util import nest + + +# Graph Keys +CRITICAL_SECTIONS = "critical_sections" +CRITICAL_SECTION_EXECUTIONS = "critical_section_executions" + + +class _ExecutionSignature( + collections.namedtuple("_ExecutionSignature", + ("op", "exclusive_resource_access"))): + """A class storing an `ExecuteInCriticalResource` op and associated attrs.""" + pass + + +class CriticalSection(object): + """Critical section. + + A `CriticalSection` object is a resource in the graph which executes subgraphs + in **serial** order. A common example of a subgraph one may wish to run + exclusively is the one given by the following function: + + ```python + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + def count(): + value = v.read_value() + with tf.control_dependencies([value]): + with tf.control_dependencies([v.assign_add(1)]): + return tf.identity(value) + ``` + + Here, a snapshot of `v` is captured in `value`; and then `v` is updated. + The snapshot value is returned. + + If multiple workers or threads all execute `count` in parallel, there is no + guarantee that access to the variable `v` is atomic at any point within + any thread's calculation of `count`. In fact, even implementing an atomic + counter that guarantees that the user will see each value `0, 1, ...,` is + currently impossible. + + The solution is to ensure any access to the underlying resource `v` is + only processed through a critical section: + + ```python + cs = CriticalSection() + f1 = cs.execute(count) + f2 = cs.execute(count) + output = f1 + f2 + session.run(output) + ``` + The functions `f1` and `f2` will be executed serially, and updates to `v` + will be atomic. + + **NOTES** + + All resource objects, including the critical section and any captured + variables of functions executed on that critical section, will be + colocated to the same device (host and cpu/gpu). + + When using multiple critical sections on the same resources, there is no + guarantee of exclusive access to those resources. This behavior is disallowed + by default (but see the kwarg `exclusive_resource_access`). + + For example, running the same function in two separate critical sections + will not ensure serial execution: + + ```python + v = tf.get_variable("v", initializer=0.0, use_resource=True) + def accumulate(up): + x = v.read_value() + with tf.control_dependencies([x]): + with tf.control_dependencies([v.assign_add(up)]): + return tf.identity(x) + ex1 = CriticalSection().execute( + accumulate, 1.0, exclusive_resource_access=False) + ex2 = CriticalSection().execute( + accumulate, 1.0, exclusive_resource_access=False) + bad_sum = ex1 + ex2 + sess.run(v.initializer) + sess.run(bad_sum) # May return 0.0 + ``` + """ + + def __init__(self, name=None, critical_section_def=None, import_scope=None): + """Creates a critical section.""" + if critical_section_def and name is not None: + raise ValueError("critical_section_def and name are mutually exclusive.") + if critical_section_def: + self._init_from_proto(critical_section_def, import_scope=import_scope) + else: + self._init_from_args(name) + + def _init_from_proto(self, critical_section_def, import_scope): + raise NotImplementedError("Not yet implemented") + # TODO(ebrevdo): Re-enable once CriticalSection is in core. + # assert isinstance( + # critical_section_def, critical_section_pb2.CriticalSectionDef) + # # Create from critical_section_def. + # g = ops.get_default_graph() + # self._handle = g.as_graph_element( + # ops.prepend_name_scope( + # critical_section_def.critical_section_name, + # import_scope=import_scope)) + + def _init_from_args(self, name): + """Initialize the CriticalSection from constructor arguments.""" + with ops.name_scope(name, "CriticalSection", []) as name: + with ops.control_dependencies(None): + # pylint: disable=protected-access + handle_name = ops._name_from_scope_name(name) + container = ops.get_default_graph()._container + # pylint: enable=protected-access + if container is None: + container = "" + self._handle = gen_resource_variable_ops.critical_section_op( + shared_name=handle_name, name=name) + if context.in_graph_mode(): + ops.add_to_collections(CRITICAL_SECTIONS, self) + + @property + def name(self): + return self._handle.op.name + + def execute(self, fn, *args, **kwargs): + """Execute function `fn(*args, **kwargs)` inside the CriticalSection. + + Args: + fn: The function to execute. Must return at least one tensor. + *args: Additional positional arguments to `fn`. + **kwargs: Additional keyword arguments to `fn`. + Several keywords are reserved for `execute`. These are: + + - name; The name to use when creating the execute operation. + - exclusive_resource_access; Whether the resources required by + `fn` should be exclusive to this `CriticalSection`. Default: `True`. + You may want to set this to `False` if you will be accessing a + resource in read-only mode in two different CriticalSections. + + Returns: + The tensors returned from `fn(*args, **kwargs)`. + + Raises: + ValueError: If `fn` attempts to use this `CriticalSection` in any nested + way. + ValueError: If `exclusive_resource_access` is not provided (is `True`) and + another `CriticalSection` has an execution requesting the same + resources as in `*args`, `**kwargs`, and any additionaly captured + inputs in `fn`. Note, even if `exclusive_resource_access` is `True`, + if another execution in another `CriticalSection` was created without + `exclusive_resource_access=True`, a `ValueError` will be raised. + """ + name = kwargs.pop("name", None) + exclusive_resource_access = kwargs.pop("exclusive_resource_access", True) + + args = nest.map_structure(ops.convert_to_tensor, args) + with ops.name_scope(name, "critical_section_execute", []): + fn_op = function.make_defun_op(fn, *args, **kwargs) + flat_dtypes = nest.flatten(fn_op.output_dtypes) + flat_shapes = nest.flatten(fn_op.output_shapes) + all_inputs = nest.flatten(args) + fn_op.captured_inputs + if self._handle in all_inputs: + raise ValueError("The function fn attempts to access the " + "CriticalSection in which it would be running. This " + "is illegal and would cause deadlocks. " + "CriticalSection: %s." % self._handle) + + if context.in_graph_mode(): + # Collections and op introspection does not work in eager + # mode. This is generally ok; since eager mode (as of + # writing) executes sequentially anyway. + all_input_resources = [ + x for x in all_inputs if x.dtype == dtypes.resource] + for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): + if sg.op.inputs[0].name == self._handle.name: + # Other executions in the same critical section are allowed. + continue + if not (exclusive_resource_access or sg.exclusive_resource_access): + # Neither execution requested exclusive access. + continue + sg_input_names = [y.name for y in sg.op.inputs[1:]] + for res in all_input_resources: + if res.name in sg_input_names: + raise ValueError( + "This execution would access resource %s; but either this " + "execution (CriticalSection: %s) or Execution '%s' " + "(CriticalSection: %s) requested exclusive resource access " + "of this resource for their critical section. Did you mean " + "to call execute with keyword argument " + "exclusive_resource_access=False?" + % (res.name, + self.name, + sg.op.name, + sg.op.inputs[0].op.name)) + + flat_outputs = gen_resource_variable_ops.execute_in_critical_section( + critical_section=self._handle, + arguments=all_inputs, + f=fn_op, + output_types=flat_dtypes, + output_shapes=flat_shapes) + + if context.in_graph_mode(): + if isinstance(flat_outputs, ops.Operation): + flat_outputs = [flat_outputs] + op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor) + else flat_outputs[0]) + signature = _ExecutionSignature( + op=op, + exclusive_resource_access=exclusive_resource_access) + ops.add_to_collections( + CRITICAL_SECTION_EXECUTIONS, signature) + + return (flat_outputs[0] + if (len(flat_outputs) == 1 + and isinstance(flat_outputs[0], ops.Operation)) + else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs)) + + # TODO(ebrevdo): Re-enable once CriticalSection is in core. + + # def to_proto(self, export_scope=None): + # """Converts a `CriticalSection` to a `CriticalSectoinDef` protocol buffer. + + # Args: + # export_scope: Optional `string`. Name scope to remove. + + # Returns: + # A `CriticalSectionDef` protocol buffer, or `None` if the + # `CriticalSection` is not in the specified name scope. + # """ + # if export_scope is None or self.handle.name.startswith(export_scope): + # cs_def = critical_section_pb2.CriticalSectionDef() + # cs_def.critical_section_name = ops.strip_name_scope( + # self._handle.name, export_scope) + # return cs_def + # else: + # return None + + # @staticmethod + # def from_proto(critical_section_def, import_scope=None): + # return CriticalSection( + # critical_section_def=critical_section_def, import_scope=import_scope) + + +# TODO(ebrevdo): Re-enable once CriticalSection is in core. + +# def _execution_to_proto_fn(execution_signature, export_scope=None): +# """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`. + +# Args: +# execution_signature: Instance of `_ExecutionSignature`. +# export_scope: The export scope, if any. + +# Returns: +# An instance of `CriticalSectionExecutionDef`. +# """ +# if (export_scope is None +# or execution_signature.op.name.startswith(export_scope)): +# op_def = critical_section_pb2.CriticalSectionExecutionDef() +# op_def.execute_in_critical_section_name = ops.strip_name_scope( +# execution_signature.op.name, export_scope) +# op_def.exclusive_resource_access = ( +# execution_signature.exclusive_resource_access) +# return op_def +# else: +# return None + + +# def _execution_from_proto_fn(op_def, import_scope=None): +# """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`.""" +# assert isinstance( +# op_def, critical_section_pb2.CriticalSectionExecutionDef) + +# # Create from op_def. +# g = ops.get_default_graph() +# execution_op = g.as_graph_element( +# ops.prepend_name_scope( +# op_def.execute_in_critical_section_name, +# import_scope=import_scope)) +# return _ExecutionSignature( +# op=execution_op, +# exclusive_resource_access=op_def.exclusive_resource_access) + +# ops.register_proto_function( +# CRITICAL_SECTIONS, +# proto_type=critical_section_pb2.CriticalSectionDef, +# to_proto=CriticalSection.to_proto, +# from_proto=CriticalSection.from_proto) + +# ops.register_proto_function( +# CRITICAL_SECTION_EXECUTIONS, +# proto_type=critical_section_pb2.CriticalSectionExecutionDef, +# to_proto=_execution_to_proto_fn, +# from_proto=_execution_from_proto_fn) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a416724d3ba1719471d70667e140f9cd2daf86c7 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -0,0 +1,178 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""critical section tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import critical_section_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +# TODO(ebrevdo): Re-enable once CriticalSection is in core. +# from tensorflow.python.training import saver as saver_lib + + +class CriticalSectionTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testCreateCriticalSection(self): + cs = critical_section_ops.CriticalSection(name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + def fn(a, b): + c = v.read_value() + with ops.control_dependencies([c]): + nv = v.assign_add(a * b) + with ops.control_dependencies([nv]): + return array_ops.identity(c) + + num_concurrent = 1000 + r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)] + self.evaluate(v.initializer) + r_value = self.evaluate(r) + self.assertAllClose([2.0 * i for i in range(num_concurrent)], + sorted(r_value)) + + @test_util.run_in_graph_and_eager_modes() + def testCreateCriticalSectionFnReturnsOp(self): + cs = critical_section_ops.CriticalSection(name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + def fn_return_op(a, b): + c = v.read_value() + with ops.control_dependencies([c]): + nv = v.assign_add(a * b) + with ops.control_dependencies([nv]): + return () + + num_concurrent = 100 + r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)] + self.evaluate(v.initializer) + self.evaluate(r) + final_v = self.evaluate(v) + self.assertAllClose(2.0 * num_concurrent, final_v) + + def testCreateCriticalSectionRaw(self): + cs = critical_section_ops.CriticalSection(name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + @function.Defun(dtypes.float32, dtypes.float32) + def fn(a, b): + c = v.read_value() + with ops.control_dependencies([c]): + nv = v.assign_add(a * b) + with ops.control_dependencies([nv]): + return array_ops.identity(c) + + def execute(fn, *args): + output_args = fn.definition.signature.output_arg + return resource_variable_ops.execute_in_critical_section( + critical_section=cs._handle, + arguments=list(args) + fn.captured_inputs, + f=fn, + output_types=[out.type for out in output_args], + output_shapes=[tensor_shape.TensorShape(None) for _ in output_args]) + + num_concurrent = 1000 + r = [execute(fn, 1.0, 2.0)[0] for _ in range(num_concurrent)] + self.evaluate(v.initializer) + r_value = self.evaluate(r) + self.assertAllClose([2.0 * i for i in range(num_concurrent)], + sorted(r_value)) + + def testCollection(self): + cs = critical_section_ops.CriticalSection(name="cs") + self.assertIn( + cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)) + execute_op = cs.execute(lambda x: x + 1, 1.0).op + self.assertIn( + execute_op, + [signature.op for signature in + ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)]) + + @test_util.run_in_graph_and_eager_modes() + def testRecursiveCriticalSectionAccessIsIllegal(self): + cs = critical_section_ops.CriticalSection(name="cs") + def fn(x): + return cs.execute(lambda x: x+1, x) + with self.assertRaisesRegexp( + ValueError, + r"attempts to access the CriticalSection in which it would be running"): + cs.execute(fn, 1.0) + + def testMultipleCSExecutionsRequestSameResource(self): + cs0 = critical_section_ops.CriticalSection() + cs1 = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(0.0, name="v") + cs0.execute(lambda: v + 1) + # It's OK for the same CriticalSection to access this resource. + cs0.execute(lambda: v - 1) + # It's *not* OK for a different CriticalSection to access it by + # default. + with self.assertRaisesRegexp( + ValueError, "requested exclusive resource access"): + cs1.execute(lambda: v + 1) + # It's not even OK if the second call doesn't request exclusive access. + with self.assertRaisesRegexp( + ValueError, "requested exclusive resource access"): + cs1.execute(lambda: v + 1, exclusive_resource_access=False) + + v2 = resource_variable_ops.ResourceVariable(0.0, name="v2") + cs0.execute(lambda: v2 + 1, exclusive_resource_access=False) + # It's OK if neither requests exclusive resource access. + cs1.execute(lambda: v2 + 1, exclusive_resource_access=False) + + # It's not OK if the second request requires exlusive resource + # access. + with self.assertRaisesRegexp( + ValueError, "requested exclusive resource access"): + cs1.execute(lambda: v2 + 1) + + # TODO(ebrevdo): Re-enable once CriticalSection is in core. + # + # def testCriticalSectionAndExecuteOpSaverRoundTrip(self): + # cs = critical_section_ops.CriticalSection() + # r = cs.execute(lambda x: x + 1, 1.0) + # graph = ops.get_default_graph() + # meta_graph = saver_lib.export_meta_graph( + # graph=graph, collection_list=graph.get_all_collection_keys()) + # graph_copy = ops.Graph() + # with graph_copy.as_default(): + # _ = saver_lib.import_meta_graph(meta_graph, import_scope="imported") + # restored_cs = ops.get_collection(critical_section_ops.CRITICAL_SECTIONS) + # restored_exec = ops.get_collection( + # critical_section_ops.CRITICAL_SECTION_EXECUTIONS) + # self.assertEqual(1, len(restored_cs)) + # self.assertEqual(1, len(restored_exec)) + # self.assertEqual(restored_cs[0].name, "imported/%s" % cs.name) + # self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name) + + # def testToProto(self): + # cs = critical_section_ops.CriticalSection(name="cs") + # proto = cs.to_proto() + # self.assertEqual(proto.critical_section_name, cs._handle.name) + # cs_copy = critical_section_ops.CriticalSection.from_proto(proto) + # self.assertEqual(cs_copy._handle, cs._handle) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/framework/python/ops/script_ops.py b/tensorflow/contrib/framework/python/ops/script_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5d269fefdcfae7902b35e0f29f8cd12fcc58b882 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/script_ops.py @@ -0,0 +1,143 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Script Language Operators. See the @{$python/script_ops} guide. + +@@py_func +""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops.script_ops import py_func as _py_func +from tensorflow.python.util import nest + +__all__ = ['py_func'] + + +def py_func(func, + args=(), + kwargs=None, + output_types=None, + output_shapes=None, + stateful=True, + name=None): + """Wraps a python function and uses it as a TensorFlow op. + + This function is a wrapper around `tf.py_func` and improve it with kwargs + and output_shapes. Further it changed some argument names. + + Given a python function `func`, which takes numpy arrays as its + inputs and returns numpy arrays as its outputs, wrap this function as an + operation in a TensorFlow graph. The following snippet constructs a simple + TensorFlow graph that invokes the `np.sinh()` NumPy function as a operation + in the graph: + + ```python + def my_func(x): + # x will be a numpy array with the contents of the placeholder below + return np.sinh(x) + inp = tf.placeholder(tf.float32) + y = tf.py_func(my_func, [inp], tf.float32) + ``` + + + **N.B.** The `tf.py_func()` operation has the following known limitations: + + * The body of the function (i.e. `func`) will not be serialized in a + `GraphDef`. Therefore, you should not use this function if you need to + serialize your model and restore it in a different environment. + + * The operation must run in the same address space as the Python program + that calls `tf.py_func()`. If you are using distributed TensorFlow, you + must run a `tf.train.Server` in the same process as the program that calls + `tf.py_func()` and you must pin the created operation to a device in that + server (e.g. using `with tf.device():`). + + Args: + func: A Python function, which accepts a list of NumPy `ndarray` objects + having element types that match the corresponding `tf.Tensor` objects + in `inp`, and returns a list of `ndarray` objects (or a single `ndarray`) + having element types that match the corresponding values in `Tout`. + args: A list of `Tensor` objects. + kwargs: A dict with `Tensor` objects as values. + output_types: A nested structure of tensorflow data types or a single + tensorflow data type if there is only one, indicating what `func` returns. + output_shapes: Same as output_types, except the types are replaces with + shapes (optional). + stateful: (Boolean.) If True, the function should be considered stateful. + If a function is stateless, when given the same input it will return the + same output and have no observable side effects. Optimizations such as + common subexpression elimination are only performed on stateless + operations. + name: A name for the operation (optional). + + Returns: + Tensorflow op that wraps the input python function. + """ + + if kwargs is None: + kwargs = {} + + if not isinstance(args, (list, tuple)): + raise TypeError('args must be list and not {}. args: {}'.format( + type(args), args)) + + if not isinstance(kwargs, dict): + raise TypeError('kwargs must be dict and not {}. args: {}'.format( + type(kwargs), kwargs)) + + # For dynamic type inference use callable output_types and output_shapes + if callable(output_types): + # If callable assume same signature and call with tensors and get the types + output_types = output_types(*args, **kwargs) + if callable(output_shapes): + # If callable assume same signature and call with tensors and get the shapes + output_shapes = output_shapes(*args, **kwargs) + + flat_output_types = nest.flatten(output_types) + args = (args, kwargs) + flat_args = nest.flatten(args) + + def python_function_wrapper(*py_args): + py_args, py_kwargs = nest.pack_sequence_as(args, py_args) + + ret = func(*py_args, **py_kwargs) + # TODO(alextp): Catch Exceptions and improve msg, because tensorflow + # ist not able to preserve the traceback, i.e. the Exceptions does not + # contain any information where the Exception was raised. + nest.assert_shallow_structure(output_types, ret) + return nest.flatten(ret) + + flat_values = _py_func( + python_function_wrapper, + flat_args, + flat_output_types, + stateful=stateful, + name=name) + + if output_shapes is not None: + # I am not sure if this is nessesary + output_shapes = nest.map_structure_up_to( + output_types, tensor_shape.as_shape, output_shapes) + + flattened_shapes = nest.flatten(output_shapes) + for ret_t, shape in zip(flat_values, flattened_shapes): + ret_t.set_shape(shape) + + return nest.pack_sequence_as(output_types, flat_values) diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 07b7857e7b2114d251ebb5c14eda9dff0d55bbef..3f1ece4510578b5ac39849c577fffbb2a3be45a7 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -441,7 +441,7 @@ def get_unique_variable(var_op_name): """ candidates = get_variables(scope=var_op_name) if not candidates: - raise ValueError('Couldnt find variable %s' % var_op_name) + raise ValueError('Couldn\'t find variable %s' % var_op_name) for candidate in candidates: if candidate.op.name == var_op_name: diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 88306094ab9947c9c78b03c0013f6afc88316803..0e06575d96f9b9538f0245b12d48cfd7c0e8d981 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/util/use_cudnn.h" #if GOOGLE_CUDA +#include "cuda/include/cudnn.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/util/activation_mode.h" @@ -278,6 +279,28 @@ Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor, return Status::OK(); } +// Adjusts padding so cudnn supports it. Sets `adjusted_padding` to be the +// adjusted padding, and `extra_padding_before` and `extra_padding_after` to be +// the extra padding that FusedConv needs to apply before calling cudnn. +void AdjustPaddingForCudnn(int padding, bool is_int8x4, int filter_size, + int* adjusted_padding, int* extra_padding_before, + int* extra_padding_after) { +#if CUDNN_VERSION < 7000 + if (is_int8x4 && filter_size >= 6) { + // TODO(b/70795525): Remove after NVIDIA fixes this bug with int8 fused + // convolution. I don't know cuDNN7 still has the bug, so enable this + // workaround for cuDNN6 or older. + *adjusted_padding = 0; + *extra_padding_before = padding / 2; + *extra_padding_after = padding - *extra_padding_before; + return; + } +#endif + *adjusted_padding = padding / 2 * 2; + *extra_padding_before = 0; + *extra_padding_after = padding % 2; +} + template void LaunchFusedConv2DBiasActivationOp:: launch(OpKernelContext* ctx, bool cudnn_use_autotune, @@ -303,7 +326,7 @@ void LaunchFusedConv2DBiasActivationOp:: stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); OP_REQUIRES( - ctx, cc_major >= 6 && cc_minor >= 1, + ctx, ((cc_major == 6 && cc_minor >= 1) || cc_major > 6), errors::Unimplemented( "FusedConv2DBiasActivation for int8 is only supported on GPUs with " "compute capability 6.1 or later.")); @@ -338,12 +361,21 @@ void LaunchFusedConv2DBiasActivationOp:: 0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows); padding_cols = std::max( 0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols); - const int padding_rows_parity = padding_rows & 1; - const int padding_cols_parity = padding_cols & 1; - if ((padding_rows_parity | padding_cols_parity) != 0) { + int extra_top_padding = 0; + int extra_bottom_padding = 0; + int extra_left_padding = 0; + int extra_right_padding = 0; + AdjustPaddingForCudnn(padding_rows, is_int8x4, filter_rows, &padding_rows, + &extra_top_padding, &extra_bottom_padding); + AdjustPaddingForCudnn(padding_cols, is_int8x4, filter_cols, &padding_cols, + &extra_left_padding, &extra_right_padding); + if (extra_top_padding != 0 || extra_bottom_padding != 0 || + extra_left_padding != 0 || extra_right_padding != 0) { Tensor transformed_input; - const int new_conv_input_rows = conv_input_rows + padding_rows_parity; - const int new_conv_input_cols = conv_input_cols + padding_cols_parity; + const int new_conv_input_rows = + conv_input_rows + extra_top_padding + extra_bottom_padding; + const int new_conv_input_cols = + conv_input_cols + extra_left_padding + extra_right_padding; using VectT = typename Int8x4ToInt32::type>::type; auto pad_data_format = is_int8x4 ? FORMAT_NCHW : data_format; @@ -361,8 +393,9 @@ void LaunchFusedConv2DBiasActivationOp:: maybe_padded_conv_input.reinterpret_last_dimension()); functor::PadInput()( - ctx->eigen_device(), conv_input_eigen_tensor, {{0, 0}}, - {{padding_rows_parity, padding_cols_parity}}, + ctx->eigen_device(), conv_input_eigen_tensor, + {{extra_top_padding, extra_left_padding}}, + {{extra_bottom_padding, extra_right_padding}}, padded_conv_input_eigen_tensor, pad_data_format); conv_input = &maybe_padded_conv_input; @@ -439,6 +472,8 @@ void LaunchFusedConv2DBiasActivationOp:: .set_feature_map_count(output_depth) .set_layout(data_layout); dnn::ConvolutionDescriptor conv_desc; + CHECK_EQ(0, padding_rows % 2); + CHECK_EQ(0, padding_cols % 2); conv_desc.set_vertical_filter_stride(row_stride) .set_horizontal_filter_stride(col_stride) .set_zero_padding_height(padding_rows / 2) @@ -493,6 +528,8 @@ void LaunchFusedConv2DBiasActivationOp:: {{conv_input_rows, conv_input_cols}}, output_depth, {{filter_rows, filter_cols}}, + // TODO(yangzihao): Add support for arbitrary dilations for fused conv. + {{1, 1}}, // dilation_rows, dilation_cols {{row_stride, col_stride}}, {{padding_rows, padding_cols}}, conv_input->dtype(), diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h index dc43af11580ce5fda74ee25da6c151a5b89c7aee..fa7a3c03aa35c756252b22a004be91fa24c10e41 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -30,11 +30,12 @@ class FusedConvParameters : public ConvParameters { public: FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, int64 out_depths, const SpatialArray& filter, - const SpatialArray& stride, const SpatialArray& padding, - DataType dtype, int device_id, bool has_side_input, + const SpatialArray& dilation, const SpatialArray& stride, + const SpatialArray& padding, DataType dtype, + int device_id, bool has_side_input, ActivationMode activation_mode) - : ConvParameters(batch, in_depths, in, out_depths, filter, stride, - padding, dtype, device_id), + : ConvParameters(batch, in_depths, in, out_depths, filter, dilation, + stride, padding, dtype, device_id), activation_mode_(activation_mode), has_side_input_(has_side_input) { hash_code_ = Hash64Combine(hash_code_, has_side_input); diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 887ebc5a6c35379476fa1a643c866d38e2b25699..6a56237f67c844a3daa546eb02d64c9e2658f639 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -52,6 +52,7 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") .Attr("activation_mode: {'Relu'} = 'Relu'") + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; using shape_inference::DimensionHandle; @@ -151,6 +152,11 @@ REGISTER_OP("FusedConv2DBiasActivation") kernel_height, kernel_width, input_channels % 4 ]` activation_mode: The activation applied to the output. Currently must be "Relu". + dilations: 1-D tensor of length 4. The dilation factor for each dimension + of `input`. If set to k > 1, there will be k-1 skipped cells between + each filter element on that dimension. The dimension order is determined + by the value of `data_format`, see above for details. Dilations in the + batch and depth dimensions must be 1. )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 2a18f3eeecc7e0e69c54b219886a263136f01b2c..bb155aa2496cbafd9f0630d3dffb2ba69395186c 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -658,6 +658,36 @@ def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, class FusedConvInt8Tests(test.TestCase): _test_params = [ + { + "batch_size": 1, + "input_channels": 4, + "output_channels": 4, + "input_height": 8, + "input_width": 8, + "filter_height": 6, + "filter_width": 6, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 1, + "input_channels": 4, + "output_channels": 4, + "input_height": 6, + "input_width": 6, + "filter_height": 6, + "filter_width": 6, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "SAME" + }, { "batch_size": 2, "input_channels": 8, diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index abe4665caa9b23b5663df48487c6c77d33d15c59..b355a79b1a5d967eb82a30d41c073bbb52e0364c 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -56,6 +56,7 @@ py_test( srcs = ["python/train_test.py"], srcs_version = "PY2AND3", deps = [ + ":features", ":namedtuples", ":train", "//tensorflow/contrib/framework:framework_py", @@ -82,6 +83,7 @@ py_library( deps = [ ":classifier_metrics", ":eval_utils", + ":sliced_wasserstein", ":summaries", "//tensorflow/python:util", ], @@ -116,7 +118,7 @@ py_library( deps = [ ":clip_weights", ":conditioning_utils", - ":tensor_pool", + ":random_tensor_pool", ":virtual_batchnorm", "//tensorflow/python:util", ], @@ -221,10 +223,10 @@ py_test( ) py_library( - name = "tensor_pool", + name = "random_tensor_pool", srcs = [ - "python/features/python/tensor_pool.py", - "python/features/python/tensor_pool_impl.py", + "python/features/python/random_tensor_pool.py", + "python/features/python/random_tensor_pool_impl.py", ], srcs_version = "PY2AND3", deps = [ @@ -239,11 +241,11 @@ py_library( ) py_test( - name = "tensor_pool_test", - srcs = ["python/features/python/tensor_pool_test.py"], + name = "random_tensor_pool_test", + srcs = ["python/features/python/random_tensor_pool_test.py"], srcs_version = "PY2AND3", deps = [ - ":tensor_pool", + ":random_tensor_pool", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -502,6 +504,41 @@ py_test( ], ) +py_library( + name = "sliced_wasserstein", + srcs = [ + "python/eval/python/sliced_wasserstein.py", + "python/eval/python/sliced_wasserstein_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sliced_wasserstein_test", + srcs = ["python/eval/python/sliced_wasserstein_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":sliced_wasserstein", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 4bca0a1d62a2b404c6783c7cfe3b5c67cfc58221..4ead66ca13e74bacc0e4679a8d5c4e0f23d04b69 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -99,8 +99,8 @@ gan_model = tfgan.gan_model( # Build the GAN loss. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss) + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss) # Create the train ops, which calculate gradients and apply updates to weights. train_ops = tfgan.gan_train_ops( @@ -161,8 +161,8 @@ gan_model = tfgan.gan_model( # Build the GAN loss and standard pixel loss. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, gradient_penalty=1.0) l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) @@ -193,8 +193,8 @@ gan_model = tfgan.gan_model( # Build the GAN loss and standard pixel loss. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.least_squares_generator_loss, - discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss) + generator_loss_fn=tfgan.losses.least_squares_generator_loss, + discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss) l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) # Modify the loss tuple to include the pixel loss. @@ -223,8 +223,8 @@ gan_model = tfgan.infogan_model( # Build the GAN loss with mutual information penalty. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, gradient_penalty=1.0, mutual_information_penalty_weight=1.0) diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index dff361fdc42708ea69999c2def4721f9d49fcf14..f1946c7f925660eae3aaa650c437e03da1f33d6c 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN is a lightweight library for training and evaluating GANs. + +In addition to providing the infrastructure for easily training and evaluating +GANS, this library contains modules for a TFGAN-backed Estimator, +evaluation metrics, features (such as virtual batch normalization), and losses. +Please see README.md for details and usage. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 8c4a18228039cb4f2c06e0333f4b8408f1f631e9..c9f7bc61b25230e4159cf8cbc7c9cceead0aa706 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN estimator module. + +GANEstimator provides all the infrastructure support of a TensorFlow Estimator +with the feature support of TFGAN. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 058dc1d1f8cc176dcdb81268da2c4704d7eddc99..0d51c282a8977871185fb4200082feb7868cdbae 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -96,7 +96,7 @@ class GANEstimator(estimator.Estimator): # Generate samples from generator. predictions = np.array([ x for x in gan_estimator.predict(predict_input_fn)]) - ``` + ``` """ def __init__(self, @@ -107,6 +107,7 @@ class GANEstimator(estimator.Estimator): discriminator_loss_fn=None, generator_optimizer=None, discriminator_optimizer=None, + get_hooks_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -137,6 +138,10 @@ class GANEstimator(estimator.Estimator): work. discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. These hooks are run on the generator and discriminator + train ops, and can be used to implement the GAN training scheme. + Defaults to `train.get_sequential_train_hooks()`. add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. @@ -151,7 +156,7 @@ class GANEstimator(estimator.Estimator): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries) + use_loss_summaries, get_hooks_fn=get_hooks_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) @@ -160,11 +165,6 @@ class GANEstimator(estimator.Estimator): model_fn=_model_fn, model_dir=model_dir, config=config) -def _use_check_shapes(real_data): - """Determines whether TFGAN should check Tensor shapes.""" - return isinstance(real_data, ops.Tensor) - - def _gan_model_fn( features, labels, @@ -242,7 +242,7 @@ def _make_gan_model(generator_fn, discriminator_fn, real_data, real_data, generator_inputs, generator_scope=generator_scope, - check_shapes=_use_check_shapes(real_data)) + check_shapes=False) if add_summaries: if not isinstance(add_summaries, (tuple, list)): add_summaries = [add_summaries] diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 204c646e194319c0e63599da0b2a4909ef270ef3..a21358c50bbdb4a1a929b0c5bc322cec4c9923b5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -71,7 +71,7 @@ class GANHead(head._Head): # pylint: disable=protected-access def __init__(self, generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, - get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + get_hooks_fn=None, name=None): """`Head` for GAN training. @@ -86,10 +86,12 @@ class GANHead(head._Head): # pylint: disable=protected-access use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + of hooks. Defaults to `train.get_sequential_train_hooks()` name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + if get_hooks_fn is None: + get_hooks_fn = tfgan_train.get_sequential_train_hooks() # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py index bb8046187807d0cc584f7174eb9aac578855c110..f86b8513053a45f9830411f7df2c32d1f36a97b2 100644 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ b/tensorflow/contrib/gan/python/eval/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN evaluation module. + +This module supports techniques such as Inception Score, Frechet Inception +distance, and Sliced Wasserstein distance. +""" # pylint: disable=,wildcard-import,unused-import from __future__ import absolute_import @@ -22,10 +26,12 @@ from __future__ import print_function # Collapse eval into a single namespace. from tensorflow.contrib.gan.python.eval.python import classifier_metrics from tensorflow.contrib.gan.python.eval.python import eval_utils +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein from tensorflow.contrib.gan.python.eval.python import summaries from tensorflow.contrib.gan.python.eval.python.classifier_metrics import * from tensorflow.contrib.gan.python.eval.python.eval_utils import * +from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein import * from tensorflow.contrib.gan.python.eval.python.summaries import * # pylint: enable=wildcard-import,unused-import @@ -33,7 +39,10 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'classifier_metrics', + 'sliced_wasserstein_distance', 'summaries', 'eval_utils', -] + classifier_metrics.__all__ + summaries.__all__ + eval_utils.__all__ +] + ( + classifier_metrics.__all__ + sliced_wasserstein.__all__ + + summaries.__all__ + eval_utils.__all__) remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index bb65f05b5a17e9a872e41d1dcb05aeb3cd6f6f40..986a5ff6dcbeb2ff996f49137adc6d34e14c979f 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -57,8 +57,10 @@ __all__ = [ 'run_inception', 'inception_score', 'classifier_score', + 'classifier_score_from_logits', 'frechet_inception_distance', 'frechet_classifier_distance', + 'frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] @@ -222,13 +224,13 @@ def run_inception(images, image_size: Required image width and height. See unit tests for the default values. input_tensor: Name of input Tensor. - output_tensor: Name of output Tensor. This function will compute activations - at the specified layer. Examples include INCEPTION_V3_OUTPUT and - INCEPTION_V3_FINAL_POOL which would result in this function computing + output_tensor: Name or list of output Tensors. This function will compute + activations at the specified layer. Examples include INCEPTION_V3_OUTPUT + and INCEPTION_V3_FINAL_POOL which would result in this function computing the final logits or the penultimate pooling layer. Returns: - Logits. + Tensor or Tensors corresponding to computed `output_tensor`. Raises: ValueError: If images are not the correct size. @@ -244,8 +246,14 @@ def run_inception(images, activations = run_image_classifier(images, graph_def, input_tensor, output_tensor) - if array_ops.rank(activations) != 2: - activations = layers.flatten(activations) + if isinstance(activations, list): + for i, activation in enumerate(activations): + if array_ops.rank(activation) != 2: + activations[i] = layers.flatten(activation) + else: + if array_ops.rank(activations) != 2: + activations = layers.flatten(activations) + return activations @@ -257,23 +265,26 @@ def run_image_classifier(tensor, graph_def, input_tensor, tensor: An Input tensor. graph_def: A GraphDef proto. input_tensor: Name of input tensor in graph def. - output_tensor: Name of output tensor in graph def. + output_tensor: A tensor name or list of tensor names in graph def. scope: Name scope for classifier. Returns: - Classifier output. Shape depends on the classifier used, but is often - [batch, classes]. + Classifier output if `output_tensor` is a string, or a list of outputs if + `output_tensor` is a list. Raises: - ValueError: If `image_size` is not `None`, and `tensor` are not the correct - size. + ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def. """ input_map = {input_tensor: tensor} - return_elements = [output_tensor] - classifier_output = importer.import_graph_def( - graph_def, input_map, return_elements, name=scope)[0] + is_singleton = isinstance(output_tensor, str) + if is_singleton: + output_tensor = [output_tensor] + classifier_outputs = importer.import_graph_def( + graph_def, input_map, output_tensor, name=scope) + if is_singleton: + classifier_outputs = classifier_outputs[0] - return classifier_output + return classifier_outputs def classifier_score(images, classifier_fn, num_batches=1): @@ -289,6 +300,11 @@ def classifier_score(images, classifier_fn, num_batches=1): which captures how different the network's classification prediction is from the prior distribution over classes. + NOTE: This function consumes images, computes their logits, and then + computes the classifier score. If you would like to precompute many logits for + large batches, use clasifier_score_from_logits(), which this method also + uses. + Args: images: Images to calculate the classifier score for. classifier_fn: A function that takes images and produces logits based on a @@ -312,6 +328,34 @@ def classifier_score(images, classifier_fn, num_batches=1): swap_memory=True, name='RunClassifier') logits = array_ops.concat(array_ops.unstack(logits), 0) + + return classifier_score_from_logits(logits) + + +def classifier_score_from_logits(logits): + """Classifier score for evaluating a generative model from logits. + + This method computes the classifier score for a set of logits. This can be + used independently of the classifier_score() method, especially in the case + of using large batches during evaluation where we would like precompute all + of the logits before computing the classifier score. + + This technique is described in detail in https://arxiv.org/abs/1606.03498. In + summary, this function calculates: + + exp( E[ KL(p(y|x) || p(y)) ] ) + + which captures how different the network's classification prediction is from + the prior distribution over classes. + + Args: + logits: Precomputed 2D tensor of logits that will be used to + compute the classifier score. + + Returns: + The classifier score. A floating-point scalar of the same type as the output + of `logits`. + """ logits.shape.assert_has_rank(2) # Use maximum precision for best results. @@ -328,6 +372,7 @@ def classifier_score(images, classifier_fn, num_batches=1): if logits_dtype != dtypes.float64: final_score = math_ops.cast(final_score, logits_dtype) + return final_score @@ -406,6 +451,11 @@ def frechet_classifier_distance(real_images, sample size to compute frechet classifier distance when comparing two generative models. + NOTE: This function consumes images, computes their activations, and then + computes the classifier score. If you would like to precompute many + activations for real and generated images for large batches, please use + frechet_clasifier_distance_from_activations(), which this method also uses. + Args: real_images: Real images to use to compute Frechet Inception distance. generated_images: Generated images to use to compute Frechet Inception @@ -417,7 +467,7 @@ def frechet_classifier_distance(real_images, Returns: The Frechet Inception distance. A floating-point scalar of the same type - as the output of `classifier_fn` + as the output of `classifier_fn`. """ real_images_list = array_ops.split( @@ -436,31 +486,69 @@ def frechet_classifier_distance(real_images, swap_memory=True, name='RunClassifier') - activations_dtype = activations.dtype # Split the activations by the real and generated images. real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0) # Ensure the activations have the right shapes. real_a = array_ops.concat(array_ops.unstack(real_a), 0) gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) - if activations_dtype != dtypes.float64: - real_a = math_ops.to_double(real_a) - gen_a = math_ops.to_double(gen_a) - real_a.shape.assert_has_rank(2) - gen_a.shape.assert_has_rank(2) + return frechet_classifier_distance_from_activations(real_a, gen_a) + + +def frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model from activations. + + This methods computes the Frechet classifier distance from activations of + real images and generated images. This can be used independently of the + frechet_classifier_distance() method, especially in the case of using large + batches during evaluation where we would like precompute all of the + activations before computing the classifier distance. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Args: + real_activations: 2D Tensor containing activations of real data. Shape is + [batch_size, activation_size]. + generated_activations: 2D Tensor containing activations of generated data. + Shape is [batch_size, activation_size]. + + Returns: + The Frechet Inception distance. A floating-point scalar of the same type + as the output of the activations. + + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) # Compute mean and covariance matrices of activations. - m = math_ops.reduce_mean(real_a, 0) - m_v = math_ops.reduce_mean(gen_a, 0) - num_examples = math_ops.to_double(array_ops.shape(real_a)[0]) + m = math_ops.reduce_mean(real_activations, 0) + m_v = math_ops.reduce_mean(generated_activations, 0) + num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T + real_centered = real_activations - m sigma = math_ops.matmul( - real_a - m, real_a - m, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_v sigma_v = math_ops.matmul( - gen_a - m_v, gen_a - m_v, transpose_a=True) / (num_examples - 1) + gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) # Find the Tr(sqrt(sigma sigma_v)) component of FID sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 92e0a995748c1c4c2ddfff0daae59be5a6eaefb4..1e18c699ba93b5f524341c65d0a2db84556b65a2 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -190,6 +190,23 @@ class ClassifierMetricsTest(test.TestCase): # Check that none of the model variables are trainable. self.assertListEqual([], variables.trainable_variables()) + def test_run_inception_multiple_outputs(self): + """Test `run_inception` graph construction with multiple outputs.""" + batch_size = 3 + img = array_ops.ones([batch_size, 299, 299, 3]) + logits, pool = _run_with_mock( + classifier_metrics.run_inception, img, + output_tensor=[classifier_metrics.INCEPTION_OUTPUT, + classifier_metrics.INCEPTION_FINAL_POOL]) + + self.assertTrue(isinstance(logits, ops.Tensor)) + self.assertTrue(isinstance(pool, ops.Tensor)) + logits.shape.assert_is_compatible_with([batch_size, 1001]) + pool.shape.assert_is_compatible_with([batch_size, 2048]) + + # Check that none of the model variables are trainable. + self.assertListEqual([], variables.trainable_variables()) + def test_inception_score_graph(self): """Test `inception_score` graph construction.""" score = _run_with_mock(classifier_metrics.inception_score, diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py new file mode 100644 index 0000000000000000000000000000000000000000..523968bed91f1021ae629bf52c405cf5c2d7b917 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model evaluation tools for TFGAN.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = sliced_wasserstein_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9bebcacbe46d85fc4226c4275b71b3ecbde57a97 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -0,0 +1,282 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of Sliced Wasserstein Distance. + +Proposed in https://arxiv.org/abs/1710.10196 and the official Theano +implementation that we used as reference can be found here: +https://github.com/tkarras/progressive_growing_of_gans + +Note: this is not an exact distance but an approximation through random +projections. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import script_ops + +__all__ = ['sliced_wasserstein_distance'] +_GAUSSIAN_FILTER = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ + 6, 24, 36, 24, 6 +], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]).reshape([5, 5, 1, 1]) / 256.0 + + +def _laplacian_pyramid(batch, num_levels): + """Compute a Laplacian pyramid. + + Args: + batch: (tensor) The batch of images (batch, height, width, channels). + num_levels: (int) Desired number of hierarchical levels. + Returns: + List of tensors from the highest to lowest resolution. + """ + gaussian_filter = constant_op.constant(_GAUSSIAN_FILTER) + + def spatial_conv(batch, gain): + s = array_ops.shape(batch) + padded = array_ops.pad(batch, [[0, 0], [2, 2], [2, 2], [0, 0]], 'REFLECT') + xt = array_ops.transpose(padded, [0, 3, 1, 2]) + xt = array_ops.reshape(xt, [s[0] * s[3], s[1] + 4, s[2] + 4, 1]) + conv_out = nn_ops.conv2d(xt, gaussian_filter * gain, [1] * 4, 'VALID') + conv_xt = array_ops.reshape(conv_out, [s[0], s[3], s[1], s[2]]) + conv_xt = array_ops.transpose(conv_xt, [0, 2, 3, 1]) + return conv_xt + + def pyr_down(batch): # matches cv2.pyrDown() + return spatial_conv(batch, 1)[:, ::2, ::2] + + def pyr_up(batch): # matches cv2.pyrUp() + s = array_ops.shape(batch) + zeros = array_ops.zeros([3 * s[0], s[1], s[2], s[3]]) + res = array_ops.concat([batch, zeros], 0) + res = array_ops.batch_to_space(res, crops=[[0, 0], [0, 0]], block_size=2) + res = spatial_conv(res, 4) + return res + + pyramid = [math_ops.to_float(batch)] + for _ in range(1, num_levels): + pyramid.append(pyr_down(pyramid[-1])) + pyramid[-2] -= pyr_up(pyramid[-1]) + return pyramid + + +def _batch_to_patches(batch, patches_per_image, patch_size): + """Extract patches from a batch. + + Args: + batch: (tensor) The batch of images (batch, height, width, channels). + patches_per_image: (int) Number of patches to extract per image. + patch_size: (int) Size of the patches (size, size, channels) to extract. + Returns: + Tensor (batch*patches_per_image, patch_size, patch_size, channels) of + patches. + """ + + def py_func_random_patches(batch): + """Numpy wrapper.""" + batch_size, height, width, channels = batch.shape + patch_count = patches_per_image * batch_size + hs = patch_size // 2 + # Randomly pick patches. + patch_id, y, x, chan = np.ogrid[0:patch_count, -hs:hs + 1, -hs:hs + 1, 0:3] + img_id = patch_id // patches_per_image + # pylint: disable=g-no-augmented-assignment + # Need explicit addition for broadcast to work properly. + y = y + np.random.randint(hs, height - hs, size=(patch_count, 1, 1, 1)) + x = x + np.random.randint(hs, width - hs, size=(patch_count, 1, 1, 1)) + # pylint: enable=g-no-augmented-assignment + idx = ((img_id * height + y) * width + x) * channels + chan + patches = batch.flat[idx] + return patches + + patches = script_ops.py_func( + py_func_random_patches, [batch], batch.dtype, stateful=False) + return patches + + +def _normalize_patches(patches): + """Normalize patches by their mean and standard deviation. + + Args: + patches: (tensor) The batch of patches (batch, size, size, channels). + Returns: + Tensor (batch, size, size, channels) of the normalized patches. + """ + patches = array_ops.concat(patches, 0) + mean, variance = nn.moments(patches, [1, 2, 3], keep_dims=True) + patches = (patches - mean) / math_ops.sqrt(variance) + return array_ops.reshape(patches, [array_ops.shape(patches)[0], -1]) + + +def _sort_rows(matrix, num_rows): + """Sort matrix rows by the last column. + + Args: + matrix: a matrix of values (row,col). + num_rows: (int) number of sorted rows to return from the matrix. + Returns: + Tensor (num_rows, col) of the sorted matrix top K rows. + """ + tmatrix = array_ops.transpose(matrix, [1, 0]) + sorted_tmatrix = nn_ops.top_k(tmatrix, num_rows)[0] + return array_ops.transpose(sorted_tmatrix, [1, 0]) + + +def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): + """Compute the approximate sliced Wasserstein distance. + + Args: + a: (matrix) Distribution "a" of samples (row, col). + b: (matrix) Distribution "b" of samples (row, col). + random_sampling_count: (int) Number of random projections to average. + random_projection_dim: (int) Dimension of the random projection space. + Returns: + Float containing the approximate distance between "a" and "b". + """ + s = array_ops.shape(a) + means = [] + for _ in range(random_sampling_count): + # Random projection matrix. + proj = random_ops.random_normal( + [array_ops.shape(a)[1], random_projection_dim]) + proj *= math_ops.rsqrt( + math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True)) + # Project both distributions and sort them. + proj_a = math_ops.matmul(a, proj) + proj_b = math_ops.matmul(b, proj) + proj_a = _sort_rows(proj_a, s[0]) + proj_b = _sort_rows(proj_b, s[0]) + # Pairwise Wasserstein distance. + wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) + means.append(wdist) + return math_ops.reduce_mean(means) + + +def _sliced_wasserstein_svd(a, b): + """Compute the approximate sliced Wasserstein distance using an SVD. + + This is not part of the paper, it's a variant with possibly more accurate + measure. + + Args: + a: (matrix) Distribution "a" of samples (row, col). + b: (matrix) Distribution "b" of samples (row, col). + Returns: + Float containing the approximate distance between "a" and "b". + """ + s = array_ops.shape(a) + # Random projection matrix. + sig, u = linalg_ops.svd(array_ops.concat([a, b], 0))[:2] + proj_a, proj_b = array_ops.split(u * sig, 2, axis=0) + proj_a = _sort_rows(proj_a[:, ::-1], s[0]) + proj_b = _sort_rows(proj_b[:, ::-1], s[0]) + # Pairwise Wasserstein distance. + wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) + return wdist + + +def sliced_wasserstein_distance(real_images, + fake_images, + resolution_min=16, + patches_per_image=64, + patch_size=7, + random_sampling_count=1, + random_projection_dim=7 * 7 * 3, + use_svd=False): + """Compute the Wasserstein distance between two distributions of images. + + Note that measure vary with the number of images. Use 8192 images to get + numbers comparable to the ones in the original paper. + + Args: + real_images: (tensor) Real images (batch, height, width, channels). + fake_images: (tensor) Fake images (batch, height, width, channels). + resolution_min: (int) Minimum resolution for the Laplacion pyramid. + patches_per_image: (int) Number of patches to extract per image per + Laplacian level. + patch_size: (int) Width of a square patch. + random_sampling_count: (int) Number of random projections to average. + random_projection_dim: (int) Dimension of the random projection space. + use_svd: experimental method to compute a more accurate distance. + Returns: + List of tuples (distance_real, distance_fake) for each level of the + Laplacian pyramid from the highest resoluion to the lowest. + distance_real is the Wasserstein distance between real images + distance_fake is the Wasserstein distance between real and fake images. + Raises: + ValueError: If the inputs shapes are incorrect. Input tensor dimensions + (batch, height, width, channels) are expected to be known at graph + construction time. In addition height and width must be the same and the + number of colors should be exactly 3. Real and fake images must have the + same size. + """ + height = real_images.shape[1] + real_images.shape.assert_is_compatible_with([None, None, height, 3]) + fake_images.shape.assert_is_compatible_with(real_images.shape) + + # Select resolutions. + resolution_full = int(height) + resolution_min = min(resolution_min, resolution_full) + resolution_max = resolution_full + # Base loss of detail. + resolutions = [ + 2**i + for i in range( + int(np.log2(resolution_max)), + int(np.log2(resolution_min)) - 1, -1) + ] + + # Gather patches for each level of the Laplacian pyramids. + patches_real, patches_fake, patches_test = ( + [[] for _ in resolutions] for _ in range(3)) + for lod, level in enumerate( + _laplacian_pyramid(real_images, len(resolutions))): + patches_real[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + patches_test[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + + for lod, level in enumerate( + _laplacian_pyramid(fake_images, len(resolutions))): + patches_fake[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + + for lod in range(len(resolutions)): + for patches in [patches_real, patches_test, patches_fake]: + patches[lod] = _normalize_patches(patches[lod]) + + # Evaluate scores. + scores = [] + for lod in range(len(resolutions)): + if not use_svd: + scores.append( + (_sliced_wasserstein(patches_real[lod], patches_test[lod], + random_sampling_count, random_projection_dim), + _sliced_wasserstein(patches_real[lod], patches_fake[lod], + random_sampling_count, random_projection_dim))) + else: + scores.append( + (_sliced_wasserstein_svd(patches_real[lod], patches_test[lod]), + _sliced_wasserstein_svd(patches_real[lod], patches_fake[lod]))) + return scores diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b960af28eaa969079b72c7aabcde2ad6cd1f5c68 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Sliced Wasserstein Distance.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import ndimage +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl as swd +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class ClassifierMetricsTest(test.TestCase): + + def test_laplacian_pyramid(self): + # The numpy/scipy code for reference estimation comes from: + # https://github.com/tkarras/progressive_growing_of_gans + gaussian_filter = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ + 6, 24, 36, 24, 6 + ], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]) / 256.0 + + def np_pyr_down(minibatch): # matches cv2.pyrDown() + assert minibatch.ndim == 4 + return ndimage.convolve( + minibatch, + gaussian_filter[np.newaxis, np.newaxis, :, :], + mode='mirror')[:, :, ::2, ::2] + + def np_pyr_up(minibatch): # matches cv2.pyrUp() + assert minibatch.ndim == 4 + s = minibatch.shape + res = np.zeros((s[0], s[1], s[2] * 2, s[3] * 2), minibatch.dtype) + res[:, :, ::2, ::2] = minibatch + return ndimage.convolve( + res, + gaussian_filter[np.newaxis, np.newaxis, :, :] * 4.0, + mode='mirror') + + def np_laplacian_pyramid(minibatch, num_levels): + # Note: there's a bug in the original SWD, fixed repeatability. + pyramid = [minibatch.astype('f').copy()] + for _ in range(1, num_levels): + pyramid.append(np_pyr_down(pyramid[-1])) + pyramid[-2] -= np_pyr_up(pyramid[-1]) + return pyramid + + data = np.random.normal(size=[256, 3, 32, 32]).astype('f') + pyramid = np_laplacian_pyramid(data, 3) + data_tf = array_ops.placeholder(dtypes.float32, [256, 32, 32, 3]) + pyramid_tf = swd._laplacian_pyramid(data_tf, 3) + with self.test_session() as sess: + pyramid_tf = sess.run( + pyramid_tf, feed_dict={ + data_tf: data.transpose(0, 2, 3, 1) + }) + for x in range(3): + self.assertAllClose( + pyramid[x].transpose(0, 2, 3, 1), pyramid_tf[x], atol=1e-6) + + def test_sliced_wasserstein_distance(self): + """Test the distance.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 32, 3]) + wfunc = swd.sliced_wasserstein_distance(d1, d2) + with self.test_session() as sess: + wscores = [sess.run(x) for x in wfunc] + self.assertAllClose( + np.array([0.014, 0.014], 'f'), + np.array([x[0] for x in wscores], 'f'), + rtol=0.1) + self.assertAllClose( + np.array([0.014, 0.020], 'f'), + np.array([x[1] for x in wscores], 'f'), + rtol=0.1) + + def test_sliced_wasserstein_distance_svd(self): + """Test the distance.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 32, 3]) + wfunc = swd.sliced_wasserstein_distance(d1, d2, use_svd=True) + with self.test_session() as sess: + wscores = [sess.run(x) for x in wfunc] + self.assertAllClose( + np.array([0.013, 0.013], 'f'), + np.array([x[0] for x in wscores], 'f'), + rtol=0.15) + self.assertAllClose( + np.array([0.014, 0.019], 'f'), + np.array([x[1] for x in wscores], 'f'), + rtol=0.15) + + def test_swd_mismatched(self): + """Test the inputs mismatched shapes are detected.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 31, 3]) + d3 = random_ops.random_normal([256, 31, 32, 3]) + d4 = random_ops.random_normal([255, 32, 32, 3]) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d2) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d3) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d4) + + def test_swd_not_rgb(self): + """Test that only RGB is supported.""" + d1 = random_ops.random_uniform([256, 32, 32, 1]) + d2 = random_ops.random_normal([256, 32, 32, 1]) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py index 6d0972f8db418d6fcf517cc6f7e96093ae08a9e4..4816daf760143af9f1502873b123ffad8e5ec8ce 100644 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN features module. + +This module includes support for virtual batch normalization, buffer replay, +conditioning, etc. +""" from __future__ import absolute_import from __future__ import division @@ -22,10 +26,12 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.features.python import clip_weights from tensorflow.contrib.gan.python.features.python import conditioning_utils +from tensorflow.contrib.gan.python.features.python import random_tensor_pool from tensorflow.contrib.gan.python.features.python import virtual_batchnorm from tensorflow.contrib.gan.python.features.python.clip_weights import * from tensorflow.contrib.gan.python.features.python.conditioning_utils import * +from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * # pylint: enable=unused-import,wildcard-import @@ -33,5 +39,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = clip_weights.__all__ _allowed_symbols += conditioning_utils.__all__ +_allowed_symbols += random_tensor_pool.__all__ _allowed_symbols += virtual_batchnorm.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py index 030e37ec679ec58e3b534fd3644ffe1d23173404..2b7bb5f14e7f3d1b3f913d3426efaaae19079ffb 100644 --- a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py +++ b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tfgan.python.features.clip_weights.""" +"""Tests for features.clip_weights.""" from __future__ import absolute_import from __future__ import division @@ -31,17 +31,18 @@ class ClipWeightsTest(test.TestCase): """Tests for `discriminator_weight_clip`.""" def setUp(self): + super(ClipWeightsTest, self).setUp() self.variables = [variables.Variable(2.0)] self.tuple = collections.namedtuple( 'VarTuple', ['discriminator_variables'])(self.variables) def _test_weight_clipping_helper(self, use_tuple): - loss = self.variables[0] * 2.0 + loss = self.variables[0] opt = training.GradientDescentOptimizer(1.0) if use_tuple: - opt_clip = clip_weights.weight_clip(opt, self.variables, 0.1) + opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1) else: - opt_clip = clip_weights.discriminator_weight_clip(opt, self.tuple, 0.1) + opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1) train_op1 = opt.minimize(loss, var_list=self.variables) train_op2 = opt_clip.minimize(loss, var_list=self.variables) @@ -72,10 +73,14 @@ class ClipWeightsTest(test.TestCase): clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1) else: with self.assertRaisesRegexp(ValueError, 'must be positive'): - clip_weights.clip_weights(opt, self.variables, weight_clip=-1) + clip_weights.clip_variables(opt, self.variables, weight_clip=-1) def test_incorrect_weight_clip_value_argsonly(self): self._test_incorrect_weight_clip_value_helper(False) def test_incorrect_weight_clip_value_tuple(self): self._test_incorrect_weight_clip_value_helper(True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/features/python/tensor_pool.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py similarity index 86% rename from tensorflow/contrib/gan/python/features/python/tensor_pool.py rename to tensorflow/contrib/gan/python/features/python/random_tensor_pool.py index 0bd2fa3db9427315ed623bc4d47d74683777bb94..ca904971fa8cb0440d3e0c9060f13cc214c9eaad 100644 --- a/tensorflow/contrib/gan/python/features/python/tensor_pool.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py @@ -25,11 +25,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.gan.python.features.python import tensor_pool_impl +from tensorflow.contrib.gan.python.features.python import random_tensor_pool_impl # pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.features.python.tensor_pool_impl import * +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented -__all__ = tensor_pool_impl.__all__ +__all__ = random_tensor_pool_impl.__all__ remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py similarity index 67% rename from tensorflow/contrib/gan/python/features/python/tensor_pool_impl.py rename to tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py index 79318a69d291f11b7978e898423f1dd3e757466f..4cfae0de4451880cf8229903b0eb74b1c6e2e04d 100644 --- a/tensorflow/contrib/gan/python/features/python/tensor_pool_impl.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -42,8 +42,14 @@ __all__ = [ ] -def tensor_pool(input_value, - pool_size, +def _to_tuple(x): + if isinstance(x, (list, tuple)): + return tuple(x) + return (x,) + + +def tensor_pool(input_values, + pool_size=50, pooling_probability=0.5, name='tensor_pool'): """Queue storing input values and returning random previously stored ones. @@ -57,15 +63,18 @@ def tensor_pool(input_value, `pool_size` = 0 or `pooling_probability` = 0. Args: - input_value: A `Tensor` from which to read values to be pooled. - pool_size: An integer specifying the maximum size of the pool. + input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read + values to be pooled. + pool_size: An integer specifying the maximum size of the pool. Defaults to + 50. pooling_probability: A float `Tensor` specifying the probability of getting a value from the pool, as opposed to just the current input. name: A string prefix for the name scope for all tensorflow ops. Returns: - A `Tensor` which is with given probability either the `input_value` or a - randomly chosen sample that was previously inserted in the pool. + A `Tensor`, or a list or tuple of `Tensor`s (according to the type ofx + `input_values`) which is with given probability either the `input_values` or + a randomly chosen sample that was previously inserted in the pool. Raises: ValueError: If `pool_size` is negative. @@ -74,45 +83,57 @@ def tensor_pool(input_value, if pool_size < 0: raise ValueError('`pool_size` is negative.') elif pool_size == 0: - return input_value + return input_values - with ops.name_scope('{}_pool_queue'.format(name), - values=[input_value, pooling_probability]): + original_input_values = input_values + input_values = _to_tuple(input_values) + + with ops.name_scope( + '{}_pool_queue'.format(name), + values=input_values + (pooling_probability,)): pool_queue = data_flow_ops.RandomShuffleQueue( capacity=pool_size, min_after_dequeue=0, - dtypes=[input_value.dtype], + dtypes=[v.dtype for v in input_values], shapes=None) # In pseudeo code this code does the following: # if not pool_full: - # enqueue(input_value) - # return input_value + # enqueue(input_values) + # return input_values # else - # dequeue_value = dequeue_random_sample() - # enqueue(input_value) + # dequeue_values = dequeue_random_sample() + # enqueue(input_values) # if rand() < pooling_probability: - # return dequeue_value + # return dequeue_values # else - # return input_value + # return input_values def _get_input_value_pooled(): - enqueue_op = pool_queue.enqueue(input_value) + enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): - return array_ops.identity(input_value) + return tuple(array_ops.identity(v) for v in input_values) def _get_random_pool_value_and_enqueue_input(): - dequeue_value = pool_queue.dequeue() - with ops.control_dependencies([dequeue_value]): - enqueue_op = pool_queue.enqueue(input_value) + dequeue_values = _to_tuple(pool_queue.dequeue()) + with ops.control_dependencies(dequeue_values): + enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): prob = random_ops.random_uniform( (), dtype=dtypes.float32) < pooling_probability - return control_flow_ops.cond(prob, lambda: dequeue_value, - lambda: input_value) + return control_flow_ops.cond(prob, lambda: dequeue_values, + lambda: input_values) - output_value = control_flow_ops.cond( + output_values = _to_tuple(control_flow_ops.cond( pool_queue.size() < pool_size, _get_input_value_pooled, - _get_random_pool_value_and_enqueue_input) + _get_random_pool_value_and_enqueue_input)) + + # Make sure that the shape of `output_value` is set. + for input_value, output_value in zip(input_values, output_values): + output_value.set_shape(input_value.shape) - return output_value + if isinstance(original_input_values, list): + return list(output_values) + elif isinstance(original_input_values, tuple): + return output_values + return output_values[0] diff --git a/tensorflow/contrib/gan/python/features/python/tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py similarity index 70% rename from tensorflow/contrib/gan/python/features/python/tensor_pool_test.py rename to tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py index 49b77bb3fc56b91cd419f76b6eea920df7efe4a7..d8cf549cf71838178c9da01df462d41d81595fe5 100644 --- a/tensorflow/contrib/gan/python/features/python/tensor_pool_test.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tf.contrib.gan.python.features.tensor_pool.""" +"""Tests for tf.contrib.gan.python.features.random_tensor_pool.""" from __future__ import absolute_import from __future__ import division @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.gan.python.features.python import tensor_pool_impl as tensor_pool +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -32,7 +32,8 @@ class TensorPoolTest(test.TestCase): """Checks that `input_value` can have unknown shape.""" input_value = array_ops.placeholder( dtype=dtypes.int32, shape=[None, None, 3]) - output_value = tensor_pool.tensor_pool(input_value, pool_size=10) + output_value = tensor_pool(input_value, pool_size=10) + self.assertEqual(output_value.shape.as_list(), [None, None, 3]) with self.test_session(use_gpu=True) as session: for i in range(10): @@ -43,7 +44,8 @@ class TensorPoolTest(test.TestCase): def test_pool_sequence(self): """Checks that values are pooled and returned maximally twice.""" input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - output_value = tensor_pool.tensor_pool(input_value, pool_size=10) + output_value = tensor_pool(input_value, pool_size=10) + self.assertEqual(output_value.shape.as_list(), []) with self.test_session(use_gpu=True) as session: outs = [] @@ -59,8 +61,9 @@ class TensorPoolTest(test.TestCase): def test_never_pool(self): """Checks that setting `pooling_probability` to zero works.""" input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - output_value = tensor_pool.tensor_pool( + output_value = tensor_pool( input_value, pool_size=10, pooling_probability=0.0) + self.assertEqual(output_value.shape.as_list(), []) with self.test_session(use_gpu=True) as session: for i in range(50): @@ -72,10 +75,11 @@ class TensorPoolTest(test.TestCase): input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) pool_size = 10 pooling_probability = 0.2 - output_value = tensor_pool.tensor_pool( + output_value = tensor_pool( input_value, pool_size=pool_size, pooling_probability=pooling_probability) + self.assertEqual(output_value.shape.as_list(), []) with self.test_session(use_gpu=True) as session: not_pooled = 0 @@ -89,6 +93,24 @@ class TensorPoolTest(test.TestCase): 1 - pooling_probability, atol=0.03) + def test_input_values_tuple(self): + """Checks that `input_values` can be a tuple.""" + input_values = (array_ops.placeholder(dtype=dtypes.int32, shape=[]), + array_ops.placeholder(dtype=dtypes.int32, shape=[])) + output_values = tensor_pool(input_values, pool_size=3) + self.assertEqual(len(output_values), len(input_values)) + for output_value in output_values: + self.assertEqual(output_value.shape.as_list(), []) + + with self.test_session(use_gpu=True) as session: + for i in range(10): + outs = session.run(output_values, { + input_values[0]: i, + input_values[1]: i + 1 + }) + self.assertEqual(len(outs), len(input_values)) + self.assertEqual(outs[1] - outs[0], 1) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/losses/__init__.py b/tensorflow/contrib/gan/python/losses/__init__.py index 290ff867a1e443f20a63e27fd97f53fed8a6cc11..d9bf8ebfdf65dfc76e4569dcaf26e0e51c7fc107 100644 --- a/tensorflow/contrib/gan/python/losses/__init__.py +++ b/tensorflow/contrib/gan/python/losses/__init__.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN losses and penalties. + +Losses can be used with individual arguments or with GANModel tuples. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 48f5e8e47dbcd5d32c23806b967a0d1e7403d2f7..3d4e315ebd0bd52b3b5e3e4a8655df8bfe9cebe8 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -79,6 +79,7 @@ class InfoGANModel( collections.namedtuple('InfoGANModel', GANModel._fields + ( 'structured_generator_inputs', 'predicted_distributions', + 'discriminator_and_aux_fn', ))): """An InfoGANModel contains all the pieces needed for InfoGAN training. @@ -91,6 +92,8 @@ class InfoGANModel( predicted_distributions: A list of tf.Distributions. Predicted by the recognizer, and used to evaluate the likelihood of the structured noise. List length should match `structured_generator_inputs`. + discriminator_and_aux_fn: The original discriminator function that returns + a tuple of (logits, `predicted_distributions`). """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index e9443f766bdc59cf45513c93e14390cd6126c295..c429ec48314b1f036beceb564bcf6d1e2a6d3b2e 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -215,7 +215,8 @@ def infogan_model( disc_scope, lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API structured_generator_inputs, - predicted_distributions) + predicted_distributions, + discriminator_fn) def acgan_model( @@ -326,6 +327,56 @@ def _use_aux_loss(aux_loss_weight): return False +def _tensor_pool_adjusted_model(model, tensor_pool_fn): + """Adjusts model using `tensor_pool_fn`. + + Args: + model: A GANModel tuple. + tensor_pool_fn: A function that takes (generated_data, generator_inputs), + stores them in an internal pool and returns a previously stored + (generated_data, generator_inputs) with some probability. For example + tfgan.features.tensor_pool. + + Returns: + A new GANModel tuple where discriminator outputs are adjusted by taking + pooled generator outputs as inputs. Returns the original model if + `tensor_pool_fn` is None. + + Raises: + ValueError: If tensor pool does not support the `model`. + """ + if tensor_pool_fn is None: + return model + + pooled_generated_data, pooled_generator_inputs = tensor_pool_fn( + (model.generated_data, model.generator_inputs)) + + if isinstance(model, namedtuples.GANModel): + with variable_scope.variable_scope(model.discriminator_scope, reuse=True): + dis_gen_outputs = model.discriminator_fn(pooled_generated_data, + pooled_generator_inputs) + return model._replace(discriminator_gen_outputs=dis_gen_outputs) + elif isinstance(model, namedtuples.ACGANModel): + with variable_scope.variable_scope(model.discriminator_scope, reuse=True): + (dis_pooled_gen_outputs, + dis_pooled_gen_classification_logits) = model.discriminator_fn( + pooled_generated_data, pooled_generator_inputs) + return model._replace( + discriminator_gen_outputs=dis_pooled_gen_outputs, + discriminator_gen_classification_logits= + dis_pooled_gen_classification_logits) + elif isinstance(model, namedtuples.InfoGANModel): + with variable_scope.variable_scope(model.discriminator_scope, reuse=True): + (dis_pooled_gen_outputs, + pooled_predicted_distributions) = model.discriminator_and_aux_fn( + pooled_generated_data, pooled_generator_inputs) + return model._replace( + discriminator_gen_outputs=dis_pooled_gen_outputs, + predicted_distributions=pooled_predicted_distributions) + else: + raise ValueError('Tensor pool does not support `model`: %s.' % type(model)) + + def gan_loss( # GANModel. model, @@ -338,6 +389,7 @@ def gan_loss( mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, + tensor_pool_fn=None, # Options. add_summaries=True): """Returns losses necessary to train generator and discriminator. @@ -363,6 +415,10 @@ def gan_loss( https://arxiv.org/abs/1610.09585 aux_cond_discriminator_weight: If not None: add a classification loss as in https://arxiv.org/abs/1610.09585 + tensor_pool_fn: A function that takes (generated_data, generator_inputs), + stores them in an internal pool and returns previous stored + (generated_data, generator_inputs). For example + `tf.gan.features.tensor_pool`. Defaults to None (not using tensor pool). add_summaries: Whether or not to add summaries for the losses. Returns: @@ -402,7 +458,9 @@ def gan_loss( # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) - dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries) + dis_loss = discriminator_loss_fn( + _tensor_pool_adjusted_model(model, tensor_pool_fn), + add_summaries=add_summaries) # Add optional extra losses. if _use_aux_loss(gradient_penalty_weight): diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 6b27b6926102b6e5a7ff134ceed75c23459a6534..58704e68594e947041697ec6cb1d240e1f505aae 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python import train +from tensorflow.contrib.gan.python.features.python import random_tensor_pool from tensorflow.contrib.slim.python.slim import learning as slim_learning from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -145,14 +146,16 @@ def get_infogan_model(): return namedtuples.InfoGANModel( *get_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def get_callable_infogan_model(): return namedtuples.InfoGANModel( *get_callable_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def create_infogan_model(): @@ -213,6 +216,25 @@ def get_sync_optimizer(): replicas_to_aggregate=1) +def get_tensor_pool_fn(pool_size): + + def tensor_pool_fn_impl(input_values): + return random_tensor_pool.tensor_pool(input_values, pool_size=pool_size) + + return tensor_pool_fn_impl + + +def get_tensor_pool_fn_for_infogan(pool_size): + + def tensor_pool_fn_impl(input_values): + generated_data, generator_inputs = input_values + output_values = random_tensor_pool.tensor_pool( + [generated_data] + generator_inputs, pool_size=pool_size) + return output_values[0], output_values[1:] + + return tensor_pool_fn_impl + + class GANModelTest(test.TestCase): """Tests for `gan_model`.""" @@ -409,6 +431,114 @@ class GANLossTest(test.TestCase): def test_callable_acgan(self): self._test_acgan_helper(create_callable_acgan_model) + def _check_tensor_pool_adjusted_model_outputs(self, tensor1, tensor2, + pool_size): + history_values = [] + with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + for i in range(2 * pool_size): + t1, t2 = sess.run([tensor1, tensor2]) + history_values.append(t1) + if i < pool_size: + # For [0, pool_size), the pool is not full, tensor1 should be equal + # to tensor2 as the pool. + self.assertAllEqual(t1, t2) + else: + # For [pool_size, ?), the pool is full, tensor2 must be equal to some + # historical values of tensor1 (which is previously stored in the + # pool). + self.assertTrue(any([(v == t2).all() for v in history_values])) + + # Test `_tensor_pool_adjusted_model` for gan model. + def test_tensor_pool_adjusted_model_gan(self): + model = create_gan_model() + + new_model = train._tensor_pool_adjusted_model(model, None) + # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' + self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) + self.assertIs(new_model.discriminator_gen_outputs, + model.discriminator_gen_outputs) + + pool_size = 5 + new_model = train._tensor_pool_adjusted_model( + model, get_tensor_pool_fn(pool_size=pool_size)) + self.assertIsNot(new_model.discriminator_gen_outputs, + model.discriminator_gen_outputs) + # Check values. + self._check_tensor_pool_adjusted_model_outputs( + model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, + pool_size) + + # Test _tensor_pool_adjusted_model for infogan model. + def test_tensor_pool_adjusted_model_infogan(self): + model = create_infogan_model() + + pool_size = 5 + new_model = train._tensor_pool_adjusted_model( + model, get_tensor_pool_fn_for_infogan(pool_size=pool_size)) + # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' + self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) + self.assertIsNot(new_model.discriminator_gen_outputs, + model.discriminator_gen_outputs) + self.assertIsNot(new_model.predicted_distributions, + model.predicted_distributions) + # Check values. + self._check_tensor_pool_adjusted_model_outputs( + model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, + pool_size) + + # Test _tensor_pool_adjusted_model for acgan model. + def test_tensor_pool_adjusted_model_acgan(self): + model = create_acgan_model() + + pool_size = 5 + new_model = train._tensor_pool_adjusted_model( + model, get_tensor_pool_fn(pool_size=pool_size)) + # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' + self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) + self.assertIsNot(new_model.discriminator_gen_outputs, + model.discriminator_gen_outputs) + self.assertIsNot(new_model.discriminator_gen_classification_logits, + model.discriminator_gen_classification_logits) + # Check values. + self._check_tensor_pool_adjusted_model_outputs( + model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, + pool_size) + + # Test tensor pool. + def _test_tensor_pool_helper(self, create_gan_model_fn): + model = create_gan_model_fn() + if isinstance(model, namedtuples.InfoGANModel): + tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5) + else: + tensor_pool_fn = get_tensor_pool_fn(pool_size=5) + loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) + self.assertTrue(isinstance(loss, namedtuples.GANLoss)) + + # Check values. + with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + for _ in range(10): + sess.run([loss.generator_loss, loss.discriminator_loss]) + + def test_tensor_pool_gan(self): + self._test_tensor_pool_helper(create_gan_model) + + def test_tensor_pool_callable_gan(self): + self._test_tensor_pool_helper(create_callable_gan_model) + + def test_tensor_pool_infogan(self): + self._test_tensor_pool_helper(create_infogan_model) + + def test_tensor_pool_callable_infogan(self): + self._test_tensor_pool_helper(create_callable_infogan_model) + + def test_tensor_pool_acgan(self): + self._test_tensor_pool_helper(create_acgan_model) + + def test_tensor_pool_callable_acgan(self): + self._test_tensor_pool_helper(create_callable_acgan_model) + def test_doesnt_crash_when_in_nested_scope(self): with variable_scope.variable_scope('outer_scope'): gan_model = train.gan_model( diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 2a97a79070ea3a0e634d76c5877e2307b6e2e577..14ac5296657d48c7f9e94d220c9e7e28af4d4353 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -173,6 +173,9 @@ def copy_op_handler(info, op, copy_shape=True): if op._original_op: op_._original_op = op._original_op + # Add op to the graph + info.graph_._add_op(op_) + return op_, op_.outputs diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 157e97d237021d95c935a6be66aa57842b97125c..3ff02e085ee63fabf42b3cc4389f4605455f3800 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -9,10 +9,12 @@ package(default_visibility = ["//visibility:public"]) load( "//tensorflow:tensorflow.bzl", + "tf_cc_test", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library", + "tf_py_test", ) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -23,6 +25,8 @@ tf_custom_op_library( "kernels/bipartite_match_op.cc", "kernels/image_ops.cc", "kernels/image_ops.h", + "kernels/segmentation_ops.cc", + "kernels/segmentation_ops.h", "ops/image_ops.cc", ], gpu_srcs = [ @@ -37,6 +41,8 @@ tf_kernel_library( "kernels/bipartite_match_op.cc", "kernels/image_ops.cc", "kernels/image_ops.h", + "kernels/segmentation_ops.cc", + "kernels/segmentation_ops.h", ], gpu_srcs = [ "kernels/image_ops_gpu.cu.cc", @@ -77,6 +83,7 @@ tf_custom_op_py_library( "//tensorflow/python:array_ops", "//tensorflow/python:common_shapes", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", @@ -106,10 +113,33 @@ tf_custom_op_library( name = "python/ops/_distort_image_ops.so", srcs = [ "kernels/adjust_hsv_in_yiq_op.cc", + "kernels/adjust_hsv_in_yiq_op.h", "ops/distort_image_ops.cc", ], + gpu_srcs = [ + "kernels/adjust_hsv_in_yiq_op_gpu.cu.cc", + "kernels/adjust_hsv_in_yiq_op.h", + ], deps = [ - "@protobuf_archive//:protobuf", + "//tensorflow/core/kernels:gpu_util_hdrs", + ], +) + +tf_cc_test( + name = "adjust_hsv_in_yiq_op_test", + size = "small", + srcs = [ + "kernels/adjust_hsv_in_yiq_op.h", + "kernels/adjust_hsv_in_yiq_op_test.cc", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + "//third_party/eigen3", ], ) @@ -122,19 +152,6 @@ tf_gen_op_wrapper_py( deps = [":distort_image_ops_op_lib"], ) -cc_library( - name = "distort_image_ops_cc", - srcs = [ - "kernels/adjust_hsv_in_yiq_op.cc", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//third_party/eigen3", - ], - alwayslink = 1, -) - py_library( name = "distort_image_py", srcs = [ @@ -177,6 +194,21 @@ cuda_py_test( ], ) +tf_py_test( + name = "segmentation_test", + size = "medium", + srcs = ["python/kernel_tests/segmentation_test.py"], + additional_deps = [ + ":distort_image_py", + ":image_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_custom_op_library( name = "python/ops/_single_image_random_dot_stereograms.so", srcs = [ @@ -222,6 +254,23 @@ py_library( ], ) +cuda_py_test( + name = "single_image_random_dot_stereograms_ops_test", + size = "medium", + srcs = ["python/kernel_tests/single_image_random_dot_stereograms_ops_test.py"], + additional_deps = [ + ":distort_image_py", + ":image_py", + ":single_image_random_dot_stereograms_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index d030dffadeb9d67f7ffcbc197a2a3feb9b3b122d..cc8ed117ba2edcc7a53e609381166f17a2fbb45e 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -20,6 +20,8 @@ This module provides functions for image manipulation; currently, chrominance transformas (including changing saturation and hue) in YIQ space and projective transforms (including rotation) are supported. +## Image Transformation `Ops` + @@angles_to_projective_transforms @@compose_transforms @@adjust_yiq_hsv @@ -28,19 +30,29 @@ projective transforms (including rotation) are supported. @@transform @@translate @@translations_to_projective_transforms + +## Image Segmentation `Ops` + +@@connected_components + +## Matching `Ops` + @@bipartite_match + +## Random Dot Stereogram `Ops` + @@single_image_random_dot_stereograms """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=line-too-long from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms from tensorflow.contrib.image.python.ops.image_ops import compose_transforms +from tensorflow.contrib.image.python.ops.image_ops import connected_components from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform from tensorflow.contrib.image.python.ops.image_ops import translate diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc index f4962ed69dc68d4bad06ef29d7a167e0ba8ae044..478b716d88321101c971789f36c0ff8ecd3f418e 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc @@ -12,14 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" @@ -36,10 +37,10 @@ class AdjustHsvInYiqOpBase : public OpKernel { struct ComputeOptions { const Tensor* input = nullptr; + Tensor* output = nullptr; const Tensor* delta_h = nullptr; const Tensor* scale_s = nullptr; const Tensor* scale_v = nullptr; - Tensor* output = nullptr; int64 channel_count = 0; }; @@ -65,7 +66,7 @@ class AdjustHsvInYiqOpBase : public OpKernel { scale_v.shape().DebugString())); auto channels = input.dim_size(input.dims() - 1); OP_REQUIRES( - context, channels == 3, + context, channels == kChannelSize, errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); @@ -101,53 +102,21 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { const Tensor* input = options.input; Tensor* output = options.output; const int64 channel_count = options.channel_count; - static const int kChannelSize = 3; auto input_data = input->shaped({channel_count, kChannelSize}); const float delta_h = options.delta_h->scalar()(); const float scale_s = options.scale_s->scalar()(); const float scale_v = options.scale_v->scalar()(); auto output_data = output->shaped({channel_count, kChannelSize}); + float tranformation_matrix[kChannelSize * kChannelSize] = {0}; + internal::compute_tranformation_matrix( + delta_h, scale_s, scale_v, tranformation_matrix); const int kCostPerChannel = 10; const DeviceBase::CpuWorkerThreads& worker_threads = *context->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, channel_count, kCostPerChannel, - [channel_count, &input_data, &output_data, delta_h, scale_s, scale_v]( + [channel_count, &input_data, &output_data, &tranformation_matrix]( int64 start_channel, int64 end_channel) { - // Using approximate linear transfomation described in: - // https://beesbuzz.biz/code/hsv_color_transforms.php - /** Get the constants from sympy - from sympy import Matrix - from sympy.abc import u, w - # Projection matrix to YIQ. http://en.wikipedia.org/wiki/YIQ - tyiq = Matrix([[0.299, 0.587, 0.114], - [0.596, -0.274, -0.322], - [0.211, -0.523, 0.312]]) - # Hue rotation matrix in YIQ space. - hue_proj = Matrix(3,3, [v, 0, 0, 0, vsu, -vsw, 0, vsw, vsu]) - m = tyiq.inv() * hue_proj * tyiq - **/ - // TODO(huangyp): directly compute the projection matrix from tyiq. - static const float t[kChannelSize][kChannelSize][kChannelSize] = { - {{.299, .701, .16862179492229}, - {.587, -.587, .329804745287403}, - {.114, -.114, -0.498426540209694}}, - {{.299, -.299, -.327963394172371}, - {.587, .413, .0346106879248821}, - {.114, -.114, .293352706247489}}, - {{.299, -.299, 1.24646136576682}, - {.587, -.587, -1.04322888291964}, - {.114, .886, -.203232482847173}}}; - float m[kChannelSize][kChannelSize] = {{0.}}; - float su = scale_s * std::cos(delta_h); - float sw = scale_s * std::sin(delta_h); - for (int q_index = 0; q_index < kChannelSize; q_index++) { - for (int p_index = 0; p_index < kChannelSize; p_index++) { - m[q_index][p_index] = scale_v * (t[q_index][p_index][0] + - t[q_index][p_index][1] * su + - t[q_index][p_index][2] * sw); - } - } // Applying projection matrix to input RGB vectors. const float* p = input_data.data() + start_channel * kChannelSize; float* q = output_data.data() + start_channel * kChannelSize; @@ -155,7 +124,9 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { for (int q_index = 0; q_index < kChannelSize; q_index++) { q[q_index] = 0; for (int p_index = 0; p_index < kChannelSize; p_index++) { - q[q_index] += m[q_index][p_index] * p[p_index]; + q[q_index] += + p[p_index] * + tranformation_matrix[q_index + kChannelSize * p_index]; } } p += kChannelSize; @@ -165,8 +136,33 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { } }; -REGISTER_KERNEL_BUILDER(Name("AdjustHsvInYiq").Device(DEVICE_CPU), - AdjustHsvInYiqOp); +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint("T"), + AdjustHsvInYiqOp); + +#if GOOGLE_CUDA +template <> +class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { + public: + explicit AdjustHsvInYiqOp(OpKernelConstruction* context) + : AdjustHsvInYiqOpBase(context) {} + + void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override { + const int64 number_of_elements = options.input->NumElements(); + if (number_of_elements <= 0) { + return; + } + const float* delta_h = options.delta_h->flat().data(); + const float* scale_s = options.scale_s->flat().data(); + const float* scale_v = options.scale_v->flat().data(); + functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input, + delta_h, scale_s, scale_v, options.output); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint("T"), + AdjustHsvInYiqOp); +#endif -// TODO(huangyp): add the GPU kernel } // namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h new file mode 100644 index 0000000000000000000000000000000000000000..194ae2ba47456cac66c01989a78ab4ce607d1295 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +static constexpr int kChannelSize = 3; + +namespace internal { + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void compute_tranformation_matrix( + const float delta_h, const float scale_s, const float scale_v, + float* matrix) { + static_assert(MATRIX_SIZE == kChannelSize * kChannelSize, + "Size of matrix should be 9."); + // Projection matrix from RGB to YIQ. Numbers from wikipedia + // https://en.wikipedia.org/wiki/YIQ + Eigen::Matrix3f yiq; + /* clang-format off */ + yiq << 0.299, 0.587, 0.114, + 0.596, -0.274, -0.322, + 0.211, -0.523, 0.312; + Eigen::Matrix3f yiq_inverse; + yiq_inverse << 1, 0.95617069, 0.62143257, + 1, -0.2726886, -0.64681324, + 1, -1.103744, 1.70062309; + /* clang-format on */ + // Construct hsv linear transformation matrix in YIQ space. + // https://beesbuzz.biz/code/hsv_color_transforms.php + float vsu = scale_v * scale_s * std::cos(delta_h); + float vsw = scale_v * scale_s * std::sin(delta_h); + Eigen::Matrix3f hsv_transform; + /* clang-format off */ + hsv_transform << scale_v, 0, 0, + 0, vsu, -vsw, + 0, vsw, vsu; + /* clang-format on */ + // Compute final transformation matrix = inverse_yiq * hsv_transform * yiq + Eigen::Map> eigen_matrix(matrix); + eigen_matrix = yiq_inverse * hsv_transform * yiq; +} +} // namespace internal + +#if GOOGLE_CUDA +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +struct AdjustHsvInYiqGPU { + void operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, const float* const delta_h, + const float* const scale_s, const float* const scale_v, + Tensor* const output); +}; + +} // namespace functor + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..b71ff9cd507faac66b3a33d3c02ec9b5901d814a --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +namespace internal { + +__global__ void compute_tranformation_matrix_cuda(const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + float* const matrix, + const int matrix_size) { + if (matrix_size == kChannelSize * kChannelSize) { + compute_tranformation_matrix( + *delta_h, *scale_s, *scale_v, matrix); + } +} +} // namespace internal + +namespace functor { + +void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, + const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + Tensor* const output) { + const uint64 m = channel_count; + const uint64 k = kChannelSize; + const uint64 n = kChannelSize; + auto* cu_stream = ctx->eigen_device().stream(); + OP_REQUIRES(ctx, cu_stream, errors::Internal("No GPU stream available.")); + Tensor tranformation_matrix; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DT_FLOAT, TensorShape({kChannelSize * kChannelSize}), + &tranformation_matrix)); + // TODO(huangyp): It takes about 3.5 us to comute tranformation_matrix + // with one thread. Improve its performance if necessary. + internal::compute_tranformation_matrix_cuda<<<1, 1, 0, cu_stream>>>( + delta_h, scale_s, scale_v, tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + // Call cuBlas C = A * B directly. + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + auto a_ptr = + AsDeviceMemory(input->flat().data(), input->flat().size()); + auto b_ptr = AsDeviceMemory(tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + auto c_ptr = AsDeviceMemory(output->flat().data(), + output->flat().size()); + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + // TODO(huangyp): share/use autotune cublas algorithms in Matmul.op. + bool blas_launch_status = + stream + ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } +} +} // namespace functor +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4cbbd277840133c9419f9ce3d945b7d099679dc0 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class AdjustHsvInYiqOpTest : public OpsTestBase { + protected: +}; + +TEST_F(AdjustHsvInYiqOpTest, IdentiyTransformMatrix) { + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, 1.0, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, {1, 0, 0, 0, 1, 0, 0, 0, 1}); + test::ExpectClose(matrix, expected); +} + +TEST_F(AdjustHsvInYiqOpTest, ScaleValueTransformMatrix) { + float scale_v = 2.3; + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, scale_v, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, + {scale_v, 0, 0, 0, scale_v, 0, 0, 0, scale_v}); + test::ExpectClose(matrix, expected); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.cc b/tensorflow/contrib/image/kernels/segmentation_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe8bf6e21c7b7310527668324571774e8bc50893 --- /dev/null +++ b/tensorflow/contrib/image/kernels/segmentation_ops.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs for ImageConnectedComponents in ../ops/image_ops.cc, and description +// of the algorithm in segmentation_ops.h. + +#define EIGEN_USE_THREADS + +#include "tensorflow/contrib/image/kernels/segmentation_ops.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using tensorflow::functor::BlockedImageUnionFindFunctor; +using tensorflow::functor::FindRootFunctor; +using tensorflow::functor::ImageConnectedComponentsFunctor; +using tensorflow::functor::TensorRangeFunctor; + +using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + +// Computes connected components on batches of 2D images. +template +class ImageConnectedComponents : public OpKernel { + public: + explicit ImageConnectedComponents(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& images_t = ctx->input(0); + OP_REQUIRES(ctx, images_t.shape().dims() == 3, + errors::InvalidArgument("Input images must have rank 3")); + Tensor forest_t, rank_t; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_INT64, + images_t.shape(), &forest_t)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_INT64, + images_t.shape(), &rank_t)); + Tensor* output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t)); + + // Fill forest with values from 0 to n - 1, so that each node points to + // itself. + TensorRangeFunctor()(ctx->eigen_device(), + forest_t.flat()); + auto rank = rank_t.tensor(); + rank.device(ctx->eigen_device()) = rank.constant(OutputType(0)); + + const auto images = images_t.tensor(); + auto forest = forest_t.tensor(); + ImageConnectedComponentsFunctor()( + ctx, output_t->flat(), images, forest, rank); + } +}; + +using CPUDevice = Eigen::ThreadPoolDevice; + +namespace functor { + +// Connected components CPU implementation. See `segmentation_ops.h` for a +// description of the algorithm. +template +struct ImageConnectedComponentsFunctor { + void operator()(OpKernelContext* ctx, + typename TTypes::Flat output, + typename TTypes::ConstTensor images, + typename TTypes::Tensor forest, + typename TTypes::Tensor rank) { + const int64 num_images = images.dimension(0), + num_rows = images.dimension(1), num_cols = images.dimension(2), + num_elements = images.size(); + // Bail out early for an empty image--no work to do. + if (num_elements == 0) { + return; + } + auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); + BlockedImageUnionFindFunctor union_find( + images.data(), num_rows, num_cols, forest.data(), rank.data()); + while (union_find.can_merge()) { + union_find.merge_blocks(); + int64 num_blocks_vertically = union_find.num_blocks_vertically(); + int64 num_blocks_horizontally = union_find.num_blocks_horizontally(); + // Merging each block calls union_down for each pixel in a row of the + // block, and union_right for each pixel in a column of the block. Assume + // 20 instructions for each call to union_down or union_right. find() may + // loop more while searching for the root, but this should not be very + // significant. + int cost = (union_find.block_height() + union_find.block_width()) * 20; + Shard(worker_threads->num_threads, worker_threads->workers, + num_images * num_blocks_vertically * num_blocks_horizontally, cost, + [&union_find, num_images, num_blocks_vertically, + num_blocks_horizontally](int64 start_block, int64 limit_block) { + for (int64 i = start_block; i < limit_block; i++) { + int64 block_x = i % num_blocks_horizontally; + int64 block_y = + (i / num_blocks_horizontally) % num_blocks_vertically; + int64 image = + i / (num_blocks_horizontally * num_blocks_vertically); + union_find.merge_internal_block_edges(image, block_y, block_x); + } + }); + } + FindRootFunctor()(ctx->eigen_device(), output, + images.data(), union_find); + } +}; + +} // end namespace functor + +#define REGISTER_IMAGE_CONNECTED_COMPONENTS(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("ImageConnectedComponents") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + ImageConnectedComponents) +// Connected components (arguably) make sense for number, bool, and string types +TF_CALL_NUMBER_TYPES(REGISTER_IMAGE_CONNECTED_COMPONENTS); +TF_CALL_bool(REGISTER_IMAGE_CONNECTED_COMPONENTS); +TF_CALL_string(REGISTER_IMAGE_CONNECTED_COMPONENTS); +#undef REGISTER_IMAGE_CONNECTED_COMPONENTS + +// TODO(ringwalt): Implement on GPU. We probably want to stick to the original +// algorithm by Stava and Benes there for efficiency (computing small blocks in +// shared memory in CUDA thread blocks, instead of starting with single-pixel +// blocks). + +} // end namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.h b/tensorflow/contrib/image/kernels/segmentation_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0957d5fd10f02daad3d8d51aadec9ce9da2660b5 --- /dev/null +++ b/tensorflow/contrib/image/kernels/segmentation_ops.h @@ -0,0 +1,303 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_ +#define TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_ + +// Connected component analysis. The op is described in ../ops/image_ops.cc. A +// description of the algorithm appears below. + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +namespace functor { + +template +bool is_nonzero(T value) { + return value != T(0); +} + +template <> +bool is_nonzero(string value) { + return value.size() != 0; +} + +// Processes each pixel of an image for union-find, in parallel blocks. This is +// loosely based on the algorithm in "GPU Computing Gems" by Ondrej Stava and +// Bedrich Benes, available here: +// http://hpcg.purdue.edu/bbenes/papers/Stava2011CCL.pdf +// The bulk of the process uses blocks of each image, which have each been +// processed separately. As long as there are multiple blocks in the image, we +// double the height and width of the blocks, creating new blocks which each +// consist of 2x2 previous sub-blocks. On each new block, we process adjacent +// pixels from the previous sub-blocks serially. However, the new blocks are not +// connected, so we can process each block in parallel. +// The GPU algorithm first processes blocks of a fixed size in GPU shared +// memory, with one image block per CUDA thread block. On the CPU, we just start +// with a block size of a single pixel, and borrow the rest of the algorithm +// unchanged. +template +class BlockedImageUnionFindFunctor { + public: + using OutputType = int64; + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockedImageUnionFindFunctor( + const T* images, const int64 num_rows, const int64 num_cols, + OutputType* forest, OutputType* rank) + : images_(images), + num_rows_(num_rows), + num_cols_(num_cols), + block_height_(1), + block_width_(1), + forest_(forest), + rank_(rank) {} + + // Returns the root of the tree that the pixel at the given index belongs to. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType + find(OutputType index) const { + while (forest_[index] != index) { + index = forest_[index]; + } + return index; + } + + // Returns the number of blocks along the y axis. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks_vertically() const { + return (num_rows_ + block_height_ - 1) / block_height_; + } + + // Returns the number of blocks along the x axis. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks_horizontally() const { + return (num_cols_ + block_width_ - 1) / block_width_; + } + + // Returns the total number of blocks in each image. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks() const { + return num_blocks_vertically() * num_blocks_horizontally(); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 block_height() const { + return block_height_; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 block_width() const { + return block_width_; + } + + // Returns whether we may merge again (the image contains more than one + // block). + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool can_merge() const { + return block_height_ < num_rows_ || block_width_ < num_cols_; + } + + // Doubles the block size. After this method, you must call + // `merge_internal_block_edges` for each image and each *new* block's xy + // coordinates (typically in parallel). + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void merge_blocks() { + block_height_ *= 2; + block_width_ *= 2; + } + + // Processes pairs of pixels within the block which were adjacent in the four + // sub-blocks. This must be done at each stage so that the connected + // components in each block are joined correctly. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void merge_internal_block_edges( + int64 image_index, int64 block_vertical_index, + int64 block_horizontal_index) const { + int64 block_start_y = block_vertical_index * block_height_; + int64 block_start_x = block_horizontal_index * block_width_; + // Merge the 4 sub-blocks horizontally (fixing the vertical seam). + int64 block_center_x = block_start_x + block_width_ / 2 - 1; + if (0 <= block_center_x && block_center_x + 1 < num_cols_) { + int64 merge_blocks_limit_y = + std::min(num_rows_, block_start_y + block_height_); + for (int64 y = block_start_y; y < merge_blocks_limit_y; y++) { + union_right(image_index, y, block_center_x); + } + } + // Merge the 4 sub-blocks vertically (fixing the horizontal seam). + int64 block_center_y = block_start_y + block_height_ / 2 - 1; + if (0 <= block_center_y && block_center_y + 1 < num_rows_) { + int64 merge_blocks_limit_x = + std::min(num_cols_, block_start_x + block_width_); + for (int64 x = block_start_x; x < merge_blocks_limit_x; x++) { + union_down(image_index, block_center_y, x); + } + } + } + + private: + // The input image(s). + const T* const images_; + const int64 num_rows_; + const int64 num_cols_; + // Current height of each sub-block of the image. + int64 block_height_; + // Current width of each sub-block of the image. + int64 block_width_; + // Union-find forest. This has the same size as `images_`, and each entry + // holds the index of its parent in `images_` (roots hold their own index). + // Cycles should not occur. + OutputType* const forest_; + // Union-find rank of each pixel. + OutputType* const rank_; + + // Unions the pixel with the pixel below it if applicable (both pixels are + // true, and the pixel is not in the last row). + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void union_down(OutputType batch, + OutputType row, + OutputType col) const { + T pixel = read_pixel(batch, row, col); + if (is_nonzero(pixel)) { + const int64 index_a = col + num_cols_ * (row + num_rows_ * batch); + if (row + 1 < num_rows_ && read_pixel(batch, row + 1, col) == pixel) { + const int64 index_b = col + num_cols_ * (row + 1 + num_rows_ * batch); + do_union(index_a, index_b); + } + } + } + + // Unions the pixel with the pixel to the right of it if applicable. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void union_right(OutputType batch, + OutputType row, + OutputType col) const { + T pixel = read_pixel(batch, row, col); + if (is_nonzero(pixel)) { + const int64 index_a = col + num_cols_ * (row + num_rows_ * batch); + if (col + 1 < num_cols_ && read_pixel(batch, row, col + 1) == pixel) { + const int64 index_b = col + 1 + num_cols_ * (row + num_rows_ * batch); + do_union(index_a, index_b); + } + } + } + + // Reads a pixel value in the images. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + read_pixel(const OutputType batch, const OutputType row, + const OutputType col) const { + return images_[col + num_cols_ * (row + num_rows_ * batch)]; + } + + // Unions the trees that the two pixels belong to, using their index in the + // `images_` array. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void do_union( + OutputType index_a, OutputType index_b) const { + // Find the roots of index_a and index_b in the forest, and make one the + // child of the other. + index_a = find(index_a); + index_b = find(index_b); + const OutputType rank_a = rank_[index_a]; + const OutputType rank_b = rank_[index_b]; + OutputType parent, child; + if (index_a == index_b) { + return; + } else if (rank_a < rank_b) { + parent = index_a; + child = index_b; + } else { + parent = index_b; + child = index_a; + rank_[parent]++; + } + forest_[child] = parent; + } +}; + +// Runs the ImageUnionFindFunctor on all pixels. Will require different CPU and +// GPU implementations. +template +class ImageConnectedComponentsFunctor { + public: + using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + + void operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor images, + typename TTypes::Tensor forest, + typename TTypes::Tensor rank); +}; + +// Fills a flat Tensor with indices from 0 to n - 1. +template +class TensorRangeFunctor { + public: + using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + + void operator()(const Device& device, + typename TTypes::Flat tensor) { + tensor.device(device) = tensor.generate(TensorRangeGenerator()); + } + + private: + class TensorRangeGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType + operator()(const Eigen::array& coords) const { + return coords[0]; + } + }; +}; + +// Given the union-find forest, generates the root index for each node. This +// gives us arbitrary, usually non-consecutive ids for each connected component. +// The ids are massaged in Python to get deterministic, consecutive ids. +template +class FindRootFunctor { + public: + using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + + void operator()(const Device& device, + typename TTypes::Flat component_ids, + const T* images, + const BlockedImageUnionFindFunctor& union_find) { + component_ids.device(device) = + component_ids.generate(FindRootGenerator(images, union_find)); + } + + private: + class FindRootGenerator { + const T* const images_; + const BlockedImageUnionFindFunctor union_find_; + + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE FindRootGenerator( + const T* images, BlockedImageUnionFindFunctor union_find) + : images_(images), union_find_(union_find) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType + operator()(const Eigen::array& coords) const { + if (is_nonzero(images_[coords[0]])) { + // True pixels have an arbitrary segment id > 0. The segment ids will be + // made contiguous later. + return union_find_.find(coords[0]) + 1; + } else { + // False pixels have a segment of 0. + return 0; + } + } + }; +}; + +} // end namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_ diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index 4527fdd87a8be3390fb0840410218ab74a27f0d2..68771b3d054a64ba94141c092e20df1ed6b2339b 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ -98,4 +98,34 @@ col_to_row_match_indices: A vector of length num_columns, which is the number `col_to_row_match_indices[j]`. )doc"); +REGISTER_OP("ImageConnectedComponents") + .Input("image: dtype") + .Output("components: int64") + .Attr( + "dtype: {int64, int32, uint16, int16, uint8, int8, half, float, " + "double, bool, string}") + .SetShapeFn([](InferenceContext* c) { + return shape_inference::UnchangedShape(c); + }) + .Doc(R"doc( +Find the connected components of image(s). + +For each image (along the 0th axis), all connected components of adjacent pixels +with the same non-zero value are detected and given unique ids. + +The returned `components` tensor has 0s for the zero pixels of `images`, and +arbitrary nonzero ids for the connected components of nonzero values. Ids are +unique across all of the images, and are in row-major order by the first pixel +in the component. + +Uses union-find with union by rank but not path compression, giving a runtime of +`O(n log n)`. See: + https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Time_Complexity + +image: Image(s) with shape (N, H, W). +components: Component ids for each pixel in "image". Same shape as "image". Zero + pixels all have an output of 0, and all components of adjacent pixels with + the same value are given consecutive ids, starting from 1. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc index f8b56ab1c5400694b3aa8d4a0c19c7769aa8cbce..1f41f243f2ebc0d1e884728defa160bf6d6c34ce 100755 --- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc @@ -19,6 +19,10 @@ limitations under the License. namespace tensorflow { +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + REGISTER_OP("SingleImageRandomDotStereograms") .Attr("T: {double,float,int64,int32}") .Input("depth_values: T") @@ -37,6 +41,26 @@ REGISTER_OP("SingleImageRandomDotStereograms") "output_image_shape: shape = { dim {size:1024} dim {size: 768} dim " "{size: 1}}") .Attr("output_data_window: shape = { dim {size:1022} dim {size: 757}}") + .SetShapeFn([](InferenceContext* c) { + // Validate that the output_image_shape attr is correct. + // NOTE: The output_image_shape is [X, Y, C] + // while the output data is [Y, X, C] (or [H, W, C]). + // As a result, by default the output_image_shape has the value + // of [1024, 768, 1] but the output data will be [768, 1024, 1]. + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("output_image_shape", &shape)); + ShapeHandle output_image_shape; + TF_RETURN_IF_ERROR( + c->MakeShapeFromPartialTensorShape(shape, &output_image_shape)); + DimensionHandle x_dim = c->Dim(output_image_shape, 0); + DimensionHandle y_dim = c->Dim(output_image_shape, 1); + + int colors; + TF_RETURN_IF_ERROR(c->GetAttr("number_colors", &colors)); + + c->set_output(0, c->MakeShape({y_dim, x_dim, colors > 256? c->MakeDim(3) : c->MakeDim(1)})); + return Status::OK(); + }) .Doc(R"doc( Outputs a single image random dot stereogram for export via encode_PNG/JPG OP. diff --git a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py index b85f19d29b79defa10493bdbaa4a1b237cb2a9ee..a495b58b7f6481d4cdedf73f23615d0390eb6a45 100644 --- a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py @@ -172,7 +172,7 @@ class AdjustValueInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_np = self._adjust_value_in_yiq_np(x_np, scale) y_tf = self._adjust_value_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5) + self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. @@ -237,7 +237,7 @@ class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale) y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5) + self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. @@ -291,6 +291,9 @@ class AdjustHueInYiqBenchmark(test.Benchmark): def benchmark_adjust_hue_in_yiqCpuAll(self): self._benchmark_adjust_hue_in_yiq('/cpu:0', None) + def benchmark_adjust_hue_in_yiq_gpu_all(self): + self._benchmark_adjust_hue_in_yiq(test.gpu_device_name(), None) + class AdjustSaturationInYiqBenchmark(test.Benchmark): @@ -333,6 +336,9 @@ class AdjustSaturationInYiqBenchmark(test.Benchmark): def benchmark_adjust_saturation_in_yiq_cpu_all(self): self._benchmark_adjust_saturation_in_yiq('/cpu:0', None) + def benchmark_adjust_saturation_in_yiq_gpu_all(self): + self._benchmark_adjust_saturation_in_yiq(test.gpu_device_name(), None) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py new file mode 100644 index 0000000000000000000000000000000000000000..48066cbacefe6b229a1f485486f11e8b8af7704f --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py @@ -0,0 +1,189 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for connected component analysis.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +import numpy as np + +from tensorflow.contrib.image.python.ops import image_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + +# Image for testing connected_components, with a single, winding component. +SNAKE = np.asarray( + [[0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 1, 1, 1, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0]]) # pyformat: disable + + +class SegmentationTest(test_util.TensorFlowTestCase): + + def testDisconnected(self): + arr = math_ops.cast( + [[1, 0, 0, 1, 0, 0, 0, 0, 1], + [0, 1, 0, 0, 0, 1, 0, 1, 0], + [1, 0, 1, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0]], + dtypes.bool) # pyformat: disable + expected = ( + [[1, 0, 0, 2, 0, 0, 0, 0, 3], + [0, 4, 0, 0, 0, 5, 0, 6, 0], + [7, 0, 8, 0, 0, 0, 9, 0, 0], + [0, 0, 0, 0, 10, 0, 0, 0, 0], + [0, 0, 11, 0, 0, 0, 0, 0, 0]]) # pyformat: disable + with self.test_session(): + self.assertAllEqual(image_ops.connected_components(arr).eval(), expected) + + def testSimple(self): + arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]] + with self.test_session(): + # Single component with id 1. + self.assertAllEqual( + image_ops.connected_components(math_ops.cast( + arr, dtypes.bool)).eval(), arr) + + def testSnake(self): + with self.test_session(): + # Single component with id 1. + self.assertAllEqual( + image_ops.connected_components(math_ops.cast( + SNAKE, dtypes.bool)).eval(), SNAKE) + + def testSnake_disconnected(self): + for i in range(SNAKE.shape[0]): + for j in range(SNAKE.shape[1]): + with self.test_session(): + # If we disconnect any part of the snake except for the endpoints, + # there will be 2 components. + if SNAKE[i, j] and (i, j) not in [(1, 1), (6, 3)]: + disconnected_snake = SNAKE.copy() + disconnected_snake[i, j] = 0 + components = image_ops.connected_components( + math_ops.cast(disconnected_snake, dtypes.bool)).eval() + self.assertEqual(components.max(), 2, 'disconnect (%d, %d)' % (i, + j)) + bins = np.bincount(components.ravel()) + # Nonzero number of pixels labeled 0, 1, or 2. + self.assertGreater(bins[0], 0) + self.assertGreater(bins[1], 0) + self.assertGreater(bins[2], 0) + + def testMultipleImages(self): + images = [[[1, 1, 1, 1], + [1, 0, 0, 1], + [1, 0, 0, 1], + [1, 1, 1, 1]], + [[1, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 1]], + [[1, 1, 0, 1], + [0, 1, 1, 0], + [1, 0, 1, 0], + [0, 0, 1, 1]]] # pyformat: disable + expected = [[[1, 1, 1, 1], + [1, 0, 0, 1], + [1, 0, 0, 1], + [1, 1, 1, 1]], + [[2, 0, 0, 3], + [0, 0, 0, 0], + [0, 0, 0, 0], + [4, 0, 0, 5]], + [[6, 6, 0, 7], + [0, 6, 6, 0], + [8, 0, 6, 0], + [0, 0, 6, 6]]] # pyformat: disable + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components(math_ops.cast( + images, dtypes.bool)).eval(), expected) + + def testZeros(self): + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components( + array_ops.zeros((100, 20, 50), dtypes.bool)).eval(), + np.zeros((100, 20, 50))) + + def testOnes(self): + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components( + array_ops.ones((100, 20, 50), dtypes.bool)).eval(), + np.tile(np.arange(100)[:, None, None] + 1, [1, 20, 50])) + + def testOnes_small(self): + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components(array_ops.ones((3, 5), + dtypes.bool)).eval(), + np.ones((3, 5))) + + def testRandom_scipy(self): + np.random.seed(42) + images = np.random.randint(0, 2, size=(10, 100, 200)).astype(np.bool) + expected = connected_components_reference_implementation(images) + if expected is None: + return + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components(images).eval(), expected) + + +def connected_components_reference_implementation(images): + try: + # pylint: disable=g-import-not-at-top + from scipy.ndimage import measurements + except ImportError: + logging.exception('Skipping test method because scipy could not be loaded') + return + image_or_images = np.asarray(images) + if len(image_or_images.shape) == 2: + images = image_or_images[None, :, :] + elif len(image_or_images.shape) == 3: + images = image_or_images + components = np.asarray([measurements.label(image)[0] for image in images]) + # Get the count of nonzero ids for each image, and offset each image's nonzero + # ids using the cumulative sum. + num_ids_per_image = components.reshape( + [-1, components.shape[1] * components.shape[2]]).max(axis=-1) + positive_id_start_per_image = np.cumsum(num_ids_per_image) + for i in range(components.shape[0]): + new_id_start = positive_id_start_per_image[i - 1] if i > 0 else 0 + components[i, components[i] > 0] += new_id_start + if len(image_or_images.shape) == 2: + return components[0, :, :] + else: + return components + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bf0c97245fc5c70469350ec66023f4d1474930e2 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for python single_image_random_dot_stereograms_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms \ + import single_image_random_dot_stereograms +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + +class SingleImageRandomDotStereogramsTest(test_util.TensorFlowTestCase): + + def test_shape_function_default(self): + """ + NOTE: The output_image_shape is [X, Y, C] + while the output data is [Y, X, C] (or [H, W, C]). + As a result, by default the output_image_shape has the value + of [1024, 768, 1], but the output data will be [768, 1024, 1]. + """ + x_np = [[1, 2, 3, 3, 2, 1], + [1, 2, 3, 4, 5, 2], + [1, 2, 3, 4, 5, 3], + [1, 2, 3, 4, 5, 4], + [6, 5, 4, 4, 5, 5]] + x_tf = constant_op.constant(x_np) + # By default [1024, 768, 1] => [768, 1024, 1]. + sirds_1 = single_image_random_dot_stereograms( + x_tf, + convergence_dots_size=8, + number_colors=256, + normalize=True) + shape_1 = sirds_1.get_shape().as_list() + self.assertEqual(shape_1, [768, 1024, 1]) + with self.test_session(): + r_tf_1 = sirds_1.eval() + self.assertAllEqual(shape_1, r_tf_1.shape) + + # If color > 256 then [1024, 768, 3] => [768, 1024, 3]. + sirds_2 = single_image_random_dot_stereograms( + x_tf, + convergence_dots_size=8, + number_colors=512, + normalize=True) + shape_2 = sirds_2.get_shape().as_list() + self.assertEqual(shape_2, [768, 1024, 3]) + with self.test_session(): + r_tf_2 = sirds_2.eval() + self.assertAllEqual(shape_2, r_tf_2.shape) + + # If explicitly set output_image_shape to [1200, 800, 1], + # then the output data should be [800, 1200, 1]. + sirds_3 = single_image_random_dot_stereograms( + x_tf, + convergence_dots_size=8, + number_colors=256, + normalize=True, + output_image_shape=[1200, 800, 1]) + shape_3 = sirds_3.get_shape().as_list() + self.assertEqual(shape_3, [800, 1200, 1]) + with self.test_session(): + r_tf_3 = sirds_3.eval() + self.assertAllEqual(shape_3, r_tf_3.shape) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index faedee6f87772016561671bacd87f88657eafffb..63377ae50310db51a3111c5a6e00df7d75dccc0b 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import resource_loader @@ -34,6 +35,7 @@ _image_ops_so = loader.load_op_library( _IMAGE_DTYPES = set( [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) +ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) @@ -395,4 +397,72 @@ def bipartite_match(distance_mat, return result +def connected_components(images): + """Labels the connected components in a batch of images. + + A component is a set of pixels in a single input image, which are all adjacent + and all have the same non-zero value. The components using a squared + connectivity of one (all True entries are joined with their neighbors above, + below, left, and right). Components across all images have consecutive ids 1 + through n. Components are labeled according to the first pixel of the + component appearing in row-major order (lexicographic order by + image_index_in_batch, row, col). Zero entries all have an output id of 0. + + This op is equivalent with `scipy.ndimage.measurements.label` on a 2D array + with the default structuring element (which is the connectivity used here). + + Args: + images: A 2D (H, W) or 3D (N, H, W) Tensor of boolean image(s). + + Returns: + Components with the same shape as `images`. False entries in `images` have + value 0, and all True entries map to a component id > 0. + + Raises: + TypeError: if `images` is not 2D or 3D. + """ + with ops.name_scope("connected_components"): + image_or_images = ops.convert_to_tensor(images, name="images") + if len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images + else: + raise TypeError( + "images should have rank 2 (HW) or 3 (NHW). Static shape is %s" % + image_or_images.get_shape()) + components = gen_image_ops.image_connected_components(images) + + # TODO(ringwalt): Component id renaming should be done in the op, to avoid + # constructing multiple additional large tensors. + components_flat = array_ops.reshape(components, [-1]) + unique_ids, id_index = array_ops.unique(components_flat) + id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0] + # Map each nonzero id to consecutive values. + nonzero_consecutive_ids = math_ops.range( + array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1 + + def no_zero(): + # No need to insert a zero into the ids. + return nonzero_consecutive_ids + + def has_zero(): + # Insert a zero in the consecutive ids where zero appears in unique_ids. + # id_is_zero has length 1. + zero_id_ind = math_ops.to_int32(id_is_zero[0]) + ids_before = nonzero_consecutive_ids[:zero_id_ind] + ids_after = nonzero_consecutive_ids[zero_id_ind:] + return array_ops.concat([ids_before, [0], ids_after], axis=0) + + new_ids = control_flow_ops.cond( + math_ops.equal(array_ops.shape(id_is_zero)[0], 0), no_zero, has_zero) + components = array_ops.reshape( + array_ops.gather(new_ids, id_index), array_ops.shape(components)) + if len(image_or_images.get_shape()) == 2: + return components[0, :, :] + else: + return components + + ops.NotDifferentiable("BipartiteMatch") +ops.NotDifferentiable("ImageConnectedComponents") diff --git a/tensorflow/contrib/keras/api/__init__.py b/tensorflow/contrib/keras/api/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..52e83069cb0c68b510da46149248369dce376647 100644 --- a/tensorflow/contrib/keras/api/__init__.py +++ b/tensorflow/contrib/keras/api/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index a2f320ab11291e4049c8367e1f133a4fbcb72a62..eff7dfeb4c1117e40f4faf43c5e92a52cffd6528 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -83,9 +83,11 @@ py_test( srcs_version = "PY2AND3", deps = [ ":kernel_methods", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index 208b0e1c9dbe93fb99e17e7be5ed5b6e30f4e201..f182fef067b7f523bc5ca63227265be40528b171 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -73,13 +73,13 @@ def sparse_multiclass_hinge_loss( labels)) as scope: # Check logits Tensor has valid rank. - logits_shape = logits.get_shape() - logits_rank = logits_shape.ndims + logits_rank = logits.get_shape().ndims if logits_rank != 2: raise ValueError( 'logits should have rank 2 ([batch_size, num_classes]). Given rank is' ' {}'.format(logits_rank)) - batch_size, num_classes = logits_shape[0].value, logits_shape[1].value + logits_shape = array_ops.shape(logits) + batch_size, num_classes = logits_shape[0], logits_shape[1] logits = math_ops.to_float(logits) # Check labels have valid type. diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py index 8a1a5ffe56ba283bfae514738fa87e4055f8934e..d38d8041ce1216dfb5af6e93984b35e71008610a 100644 --- a/tensorflow/contrib/kernel_methods/python/losses_test.py +++ b/tensorflow/contrib/kernel_methods/python/losses_test.py @@ -18,10 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.kernel_methods.python import losses from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -114,6 +117,26 @@ class SparseMulticlassHingeLossTest(test.TestCase): loss = losses.sparse_multiclass_hinge_loss(labels, logits) self.assertAlmostEqual(loss.eval(), 0.0, 3) + def testUnknownShape(self): + """Result keeps same with `testZeroLossInt32Labels`""" + logits_np = np.array([[1.2, -1.4, -1.0], + [1.4, 1.8, 4.0], + [0.5, 1.8, -1.0]]) + labels_np = np.array([0, 2, 1], dtype=np.int32) + + logits_shapes = [[3, 3], # batch_size, num_classes + [None, 3], + [3, None], + [None, None]] + + for batch_size, num_classes in logits_shapes: + with self.test_session(): + logits = array_ops.placeholder(dtypes.float32, shape=(batch_size, num_classes)) + labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,)) + loss = losses.sparse_multiclass_hinge_loss(labels, logits) + result = loss.eval(feed_dict={logits: logits_np, labels: labels_np}) + self.assertAlmostEqual(result, 0.0, 3) + def testCorrectPredictionsSomeClassesInsideMargin(self): """Loss is > 0 even if true class logits are higher than other classes.""" with self.test_session(): diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index 558bc294bc8ac129b3055ed46623c78a0d5a33e3..39d80addaac1fe855a37255b32bf4412b99df46a 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -286,7 +286,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, damping=0.001, layer_collection=layer_collection, momentum=0.9) - inv_update_queue = oq.OpQueue(optimizer.inv_updates_dict.values()) + inv_update_queue = oq.OpQueue(optimizer.inv_update_ops) sync_optimizer = tf.train.SyncReplicasOptimizer( opt=optimizer, replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py index 4275ceadc210ff471109b596e1c9aa260ce31ab5..0f0dbb53f45dfefe69aaa9e25caf6ba0a3cf449e 100644 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -239,3 +239,85 @@ def train_mnist_multitower(data_dir, }) return minimize( loss, accuracy, layer_collection, session_config=session_config) + + +def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): + """Train an MLP on MNIST using tf.estimator. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + + # Load a dataset. + def input_fn(): + tf.logging.info("Loading MNIST into memory.") + return mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=64, + flatten_images=True, + use_fake_data=use_fake_data) + + def model_fn(features, labels, mode, params): + """Model function for MLP trained with K-FAC. + + Args: + features: Tensor of shape [batch_size, input_size]. Input features. + labels: Tensor of shape [batch_size]. Target labels for training. + mode: tf.estimator.ModeKey. Must be TRAIN. + params: ignored. + + Returns: + EstimatorSpec for training. + + Raises: + ValueError: If 'mode' is anything other than TRAIN. + """ + del params + + if mode != tf.estimator.ModeKeys.TRAIN: + raise ValueError("Only training is supposed with this API.") + + # Build a ConvNet. + layer_collection = lc.LayerCollection() + loss, accuracy = build_model( + features, labels, num_labels=10, layer_collection=layer_collection) + + # Train with K-FAC. + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=tf.train.exponential_decay( + 0.00002, global_step, 10000, 0.5, staircase=True), + cov_ema_decay=0.95, + damping=0.0001, + layer_collection=layer_collection, + momentum=0.99) + + # Run cov_update_op every step. Run 1 inv_update_ops per step. + cov_update_op = optimizer.cov_update_op + inv_update_op = tf.group( + tf.contrib.kfac.utils.batch_execute( + global_step, optimizer.inv_update_thunks, batch_size=1)) + with tf.control_dependencies([cov_update_op, inv_update_op]): + train_op = optimizer.minimize(loss, global_step=global_step) + + # Print metrics every 5 sec. + hooks = [ + tf.train.LoggingTensorHook( + { + "loss": loss, + "accuracy": accuracy + }, every_n_secs=5), + ] + return tf.estimator.EstimatorSpec( + mode=mode, loss=loss, train_op=train_op, training_hooks=hooks) + + # Train until input_fn() is empty with Estimator. This is a prerequisite for + # TPU compatibility. + estimator = tf.estimator.Estimator(model_fn=model_fn) + estimator.train(input_fn=input_fn) diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py index b318c71a568be2d717745579df24134ceb3b6a0b..9c34ade1d2018135b3636fddb9dcc65839cd59de 100644 --- a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py +++ b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py @@ -33,7 +33,11 @@ FLAGS = None def main(argv): _ = argv - if FLAGS.num_towers > 1: + if FLAGS.use_estimator: + if FLAGS.num_towers != 1: + raise ValueError("Only 1 device supported in tf.estimator example.") + mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200) + elif FLAGS.num_towers > 1: mlp.train_mnist_multitower( FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) else: @@ -52,5 +56,9 @@ if __name__ == "__main__": type=int, default=1, help="Number of CPUs to split minibatch across.") + parser.add_argument( + "--use_estimator", + action="store_true", + help="Use tf.estimator API to train.") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py index cf92c909f4b5201bc0ffda5703136f46c7058ec6..547c4ab25d589192f2a5b65987be3b05128fe298 100644 --- a/tensorflow/contrib/kfac/examples/mnist.py +++ b/tensorflow/contrib/kfac/examples/mnist.py @@ -63,7 +63,7 @@ def load_mnist(data_dir, images = mnist_data.train.images labels = mnist_data.train.labels - dataset = tf.contrib.data.Dataset.from_tensor_slices((np.asarray( + dataset = tf.data.Dataset.from_tensor_slices((np.asarray( images, dtype=np.float32), np.asarray(labels, dtype=np.int64))) return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size) .make_one_shot_iterator().get_next()) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py index 3c98c54ef6cbd527aa0035e0b6f40be961c6308d..8d86c2bb5150cd4bc8a2b21ba050e904929e0fe9 100644 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -96,7 +96,7 @@ class ConvNetTest(tf.test.TestCase): """ x = np.asarray([[1.], [2.]]).astype(np.float32) y = np.asarray([1., 2.]).astype(np.float32) - x, y = (tf.contrib.data.Dataset.from_tensor_slices((x, y)) + x, y = (tf.data.Dataset.from_tensor_slices((x, y)) .repeat(100).batch(2).make_one_shot_iterator().get_next()) w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer()) y_hat = tf.matmul(x, w) diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py index 34a942d27f64e2583c686c2ba3240bc636ed918b..22da6c29f1b364d94432315988d844db9b95ec28 100644 --- a/tensorflow/contrib/kfac/examples/tests/mlp_test.py +++ b/tensorflow/contrib/kfac/examples/tests/mlp_test.py @@ -53,6 +53,11 @@ class MlpTest(tf.test.TestCase): mlp.train_mnist_multitower( data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) + def testTrainMnistEstimator(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 95fba59e3c96ae3c69e0b154740785b0d2bcb3c9..f4ed978174a9ddd8b54a88e60bfb48a67a2e76d2 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -17,12 +17,17 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", ], ) @@ -110,12 +115,15 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/contrib/tpu", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index 9b28c45c7263208d21b1514ae5f05b7e81e315a3..bfdb69ad02caaa57827e0ae6b3c9fc0d0ed03754 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils @@ -25,11 +27,15 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import training_util _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] @@ -119,6 +125,114 @@ class EstimatorTest(test.TestCase): estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection, mode) + def test_cov_update_thunks(self): + """Ensures covariance update ops run once per global_step.""" + with self._graph.as_default(), self.test_session() as sess: + fisher_estimator = estimator.FisherEstimator( + variables=[self.weights], + layer_collection=self.layer_collection, + cov_ema_decay=0.0, + damping=0.0) + + # Construct an op that executes one covariance update per step. + global_step = training_util.get_or_create_global_step() + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + cov_update_op_thunks = fisher_estimator.cov_update_thunks + cov_update_op = control_flow_ops.case( + [(math_ops.equal(global_step, i), thunk) + for i, thunk in enumerate(cov_update_op_thunks)]) + increment_global_step = global_step.assign_add(1) + + sess.run(variables.global_variables_initializer()) + initial_cov_values = sess.run(cov_matrices) + + # Ensure there's one update per covariance matrix. + self.assertEqual(len(cov_matrices), len(cov_update_op_thunks)) + + # Test is no-op if only 1 covariance matrix. + assert len(cov_matrices) > 1 + + for i in range(len(cov_matrices)): + # Compare new and old covariance values + new_cov_values = sess.run(cov_matrices) + is_cov_equal = [ + np.allclose(initial_cov_value, new_cov_value) + for (initial_cov_value, + new_cov_value) in zip(initial_cov_values, new_cov_values) + ] + num_cov_equal = sum(is_cov_equal) + + # Ensure exactly one covariance matrix changes per step. + self.assertEqual(num_cov_equal, len(cov_matrices) - i) + + # Run all covariance update ops. + sess.run(cov_update_op) + sess.run(increment_global_step) + + def test_inv_update_thunks(self): + """Ensures inverse update ops run once per global_step.""" + with self._graph.as_default(), self.test_session() as sess: + fisher_estimator = estimator.FisherEstimator( + variables=[self.weights], + layer_collection=self.layer_collection, + cov_ema_decay=0.0, + damping=0.0) + + # Construct op that updates one inverse per global step. + global_step = training_util.get_or_create_global_step() + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._inverses_by_damping.values() + ] + inv_update_op_thunks = fisher_estimator.inv_update_thunks + inv_update_op = control_flow_ops.case( + [(math_ops.equal(global_step, i), thunk) + for i, thunk in enumerate(inv_update_op_thunks)]) + increment_global_step = global_step.assign_add(1) + + sess.run(variables.global_variables_initializer()) + initial_inv_values = sess.run(inv_matrices) + + # Ensure there's one update per inverse matrix. This is true as long as + # there's no fan-in/fan-out or parameter re-use. + self.assertEqual(len(inv_matrices), len(inv_update_op_thunks)) + + # Test is no-op if only 1 invariance matrix. + assert len(inv_matrices) > 1 + + # Assign each covariance matrix a value other than the identity. This + # ensures that the inverse matrices are updated to something different as + # well. + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + sess.run([ + cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0]))) + for cov_matrix in cov_matrices + ]) + + for i in range(len(inv_matrices)): + # Compare new and old inverse values + new_inv_values = sess.run(inv_matrices) + is_inv_equal = [ + np.allclose(initial_inv_value, new_inv_value) + for (initial_inv_value, + new_inv_value) in zip(initial_inv_values, new_inv_values) + ] + num_inv_equal = sum(is_inv_equal) + + # Ensure exactly one inverse matrix changes per step. + self.assertEqual(num_inv_equal, len(inv_matrices) - i) + + # Run all inverse update ops. + sess.run(inv_update_op) + sess.run(increment_global_step) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 5f2b5c6cace9cd18f4cc5590ff55a9b39680a381..82accd57f0c37d140238f1884fce956654d14227 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -40,6 +40,21 @@ def _make_psd(dim): return array_ops.constant(mat) +class UtilsTest(test.TestCase): + + def testComputePiTracenorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + left_factor = array_ops.diag([1., 2., 0., 1.]) + right_factor = array_ops.ones([2., 2.]) + + # pi is the sqrt of the left trace norm divided by the right trace norm + pi = fb.compute_pi_tracenorm(left_factor, right_factor) + + pi_val = sess.run(pi) + self.assertEqual(1., pi_val) + + class FullFBTest(test.TestCase): def testFullFBInitSingleTensor(self): @@ -301,8 +316,7 @@ class FullyConnectedDiagonalFB(test.TestCase): multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -584,8 +598,7 @@ class ConvDiagonalFBTest(test.TestCase): multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -608,8 +621,9 @@ class ConvDiagonalFBTest(test.TestCase): self.kernel_size, self.kernel_size, self.input_channels + 1, self.output_channels ]) - expected_result = (expected_result[:, :, 0:-1, :], np.reshape( - expected_result[:, :, -1, :], [self.output_channels])) + expected_result = (expected_result[:, :, 0:-1, :], + np.reshape(expected_result[:, :, -1, :], + [self.output_channels])) self.assertEqual(len(result), 2) self.assertAllClose(expected_result[0], result[0]) @@ -692,8 +706,8 @@ class ConvKFCBasicFBTest(test.TestCase): sess.run(block._input_factor.make_inverse_update_ops()) sess.run(block._output_factor.make_inverse_update_ops()) - vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), np.arange( - 2, 4).reshape(2, 1).astype(np.float32)) + vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), + np.arange(2, 4).reshape(2, 1).astype(np.float32)) output = block.multiply_inverse((array_ops.constant(vector[0]), array_ops.constant(vector[1]))) @@ -776,11 +790,50 @@ class ConvKFCBasicFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) +class FullyConnectedSeriesFBTest(test.TestCase): + + def testFullyConnectedSeriesFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), inputs=[inputs], outputs=[outputs]) + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=True) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=False) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def as_tensors(tensor_or_tuple): """Converts a potentially nested tuple of np.array to Tensors.""" if isinstance(tensor_or_tuple, (tuple, list)): return tuple(as_tensors(t) for t in tensor_or_tuple) return ops.convert_to_tensor(tensor_or_tuple) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 5e2ce5a3096f5b523fafad56be742154d79e4803..753378d9f4a0d8762bafbee2ec27d6c71783dda1 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -35,18 +35,27 @@ from tensorflow.python.platform import test class MaybeColocateTest(test.TestCase): + def setUp(self): + self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS + + def tearDown(self): + ff.set_global_constants( + colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs) + def testFalse(self): + ff.set_global_constants(colocate_cov_ops_with_inputs=False) with tf_ops.Graph().as_default(): a = constant_op.constant([2.0], name='a') - with ff._maybe_colocate_with(a, False): + with ff.maybe_colocate_with(a): b = constant_op.constant(3.0, name='b') self.assertEqual([b'loc:@a'], a.op.colocation_groups()) self.assertEqual([b'loc:@b'], b.op.colocation_groups()) def testTrue(self): + ff.set_global_constants(colocate_cov_ops_with_inputs=True) with tf_ops.Graph().as_default(): a = constant_op.constant([2.0], name='a') - with ff._maybe_colocate_with(a, True): + with ff.maybe_colocate_with(a): b = constant_op.constant(3.0, name='b') self.assertEqual([b'loc:@a'], a.op.colocation_groups()) self.assertEqual([b'loc:@a'], b.op.colocation_groups()) @@ -67,12 +76,19 @@ class FisherFactorTestingDummy(ff.FisherFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError def instantiate_covariance(self): pass + def make_inverse_update_ops(self): + return [] + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -94,6 +110,10 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError @@ -109,7 +129,7 @@ class NumericalUtilsTest(test.TestCase): random_seed.set_random_seed(200) x = npr.randn(100, 3) - cov = ff._compute_cov(array_ops.constant(x)) + cov = ff.compute_cov(array_ops.constant(x)) np_cov = np.dot(x.T, x) / x.shape[0] self.assertAllClose(sess.run(cov), np_cov) @@ -121,7 +141,7 @@ class NumericalUtilsTest(test.TestCase): normalizer = 10. x = npr.randn(100, 3) - cov = ff._compute_cov(array_ops.constant(x), normalizer) + cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer) np_cov = np.dot(x.T, x) / normalizer self.assertAllClose(sess.run(cov), np_cov) @@ -132,7 +152,7 @@ class NumericalUtilsTest(test.TestCase): m, n = 3, 4 a = npr.randn(m, n) - a_homog = ff._append_homog(array_ops.constant(a)) + a_homog = ff.append_homog(array_ops.constant(a)) np_result = np.hstack([a, np.ones((m, 1))]) self.assertAllClose(sess.run(a_homog), np_result) @@ -267,13 +287,13 @@ class InverseProvidingFactorTest(test.TestCase): for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): factor.register_damped_inverse(1. / i) ops = factor.make_inverse_update_ops() - self.assertEqual(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD, len(ops)) + self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) new_invs = [] + sess.run(ops) for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): # The inverse op will assign the damped inverse of cov to the inv var. - sess.run(ops[i - 1]) new_invs.append(sess.run(factor._inverses_by_damping[1. / i])) # We want to see that the new invs are all different from each other. for i in range(len(new_invs)): @@ -331,6 +351,16 @@ class FullFactorTest(test.TestCase): factor = ff.FullFactor((tensor,), 32) self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) + def testFullFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 6], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -351,6 +381,16 @@ class NaiveDiagonalFactorTest(test.TestCase): factor = ff.NaiveDiagonalFactor((tensor,), 32) self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + def testNaiveDiagonalFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 1], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -364,18 +404,25 @@ class NaiveDiagonalFactorTest(test.TestCase): class FullyConnectedKroneckerFactorTest(test.TestCase): - def _testFullyConnectedKroneckerFactorInit(self, has_bias, final_shape): + def _testFullyConnectedKroneckerFactorInit(self, + has_bias, + final_shape, + dtype=dtypes.float32_ref): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) - self.assertEqual(final_shape, factor.get_cov().get_shape().as_list()) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual(final_shape, cov.get_shape().as_list()) def testFullyConnectedKroneckerFactorInitNoBias(self): - self._testFullyConnectedKroneckerFactorInit(False, [3, 3]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype) def testFullyConnectedKroneckerFactorInitWithBias(self): - self._testFullyConnectedKroneckerFactorInit(True, [4, 4]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype) def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: @@ -418,6 +465,18 @@ class ConvInputKroneckerFactorTest(test.TestCase): self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) + def testConvInputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + cov.get_shape().as_list()) + def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -453,6 +512,16 @@ class ConvOutputKroneckerFactorTest(test.TestCase): factor = ff.ConvOutputKroneckerFactor((tensor,)) self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) + def testConvOutputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') + factor = ff.ConvOutputKroneckerFactor((tensor,)) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([5, 5], cov.get_shape().as_list()) + def testConvOutputKroneckerFactorInitNotEnoughDims(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) @@ -471,5 +540,49 @@ class ConvOutputKroneckerFactorTest(test.TestCase): self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) +class FullyConnectedMultiKFTest(test.TestCase): + + def testFullyConnectedMultiKFInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) + + def testFullyConnectedMultiKFInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([3, 3], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,)) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index 39ce3e9337157c8206107bc40c489e44019743ab..63f45ea55b3d1f65a113e8c81a822a08613672df 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -114,5 +114,76 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): self.assertEqual(loss.num_registered_minibatches, num_towers) +class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): + + def testSample(self): + """Ensure samples can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + sample = loss.sample(42) + sample = sess.run(sample) + self.assertEqual(sample.shape, (2, 3)) + + def testEvaluateOnTargets(self): + """Ensure log probability can be evaluated correctly.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + targets = np.asarray([2, 1]).astype(np.int32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits), targets=array_ops.one_hot(targets, 3)) + neg_log_prob = loss.evaluate() + neg_log_prob = sess.run(neg_log_prob) + + # Calculate explicit log probability of targets. + probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + log_probs = np.log([ + probs[0, targets[0]], # + probs[1, targets[1]] + ]) + expected_log_prob = np.sum(log_probs) + + self.assertAllClose(neg_log_prob, -expected_log_prob) + + def testEvaluateOnSample(self): + """Ensure log probability of a sample can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + neg_log_prob = loss.evaluate_on_sample(42) + + # Simply ensure this doesn't crash. As the output is random, it's + # difficult to say if the output is correct or not... + neg_log_prob = sess.run(neg_log_prob) + + def testMultiMinibatchRegistration(self): + """Ensure this loss function supports registering multiple minibatches.""" + with ops.Graph().as_default(): + tower_logits = [] + loss = None + num_towers = 5 + for _ in range(num_towers): + logits = random_ops.random_uniform(shape=[2, 3]) + tower_logits.append(logits) + if loss is None: + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + logits) + else: + loss.register_additional_minibatch(logits) + self.assertListEqual(loss.input_minibatches, tower_logits) + self.assertEqual(loss.num_registered_minibatches, num_towers) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 55fe38e3e9aab2dbd70a45cdc8fa0c208b036db0..97a97adbf5577cd2694d3055acaa59258ad27964 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -22,11 +22,15 @@ import numpy as np import numpy.random as npr from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -95,6 +99,18 @@ class SubGraphTest(test.TestCase): filtered_list = sub_graph.filter_list(input_list) self.assertEqual(filtered_list, [b]) + def testVariableUses(self): + with ops.Graph().as_default(): + var = variable_scope.get_variable('var', shape=[10, 10]) + resource_var = variable_scope.get_variable( + 'resource_var', shape=[10, 10], use_resource=True) + x = array_ops.zeros([3, 10]) + z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var) + z1 = math_ops.matmul(x, resource_var) + sub_graph = utils.SubGraph((z0, z1)) + self.assertEqual(2, sub_graph.variable_uses(var)) + self.assertEqual(1, sub_graph.variable_uses(resource_var)) + class UtilsTest(test.TestCase): @@ -222,18 +238,6 @@ class UtilsTest(test.TestCase): self.assertAllClose(b, np.array([4., 5.])) self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) - def testComputePi(self): - with ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) - left_factor = array_ops.diag([1., 2., 0., 1.]) - right_factor = array_ops.ones([2., 2.]) - - # pi is the sqrt of the left trace norm divided by the right trace norm - pi = utils.compute_pi(left_factor, right_factor) - - pi_val = sess.run(pi) - self.assertEqual(1., pi_val) - def testPosDefInvCholesky(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -265,6 +269,62 @@ class UtilsTest(test.TestCase): np_inv = np.linalg.inv(x + damp * np.eye(size)) self.assertAllClose(sess.run(tf_inv), np_inv) + def testCrossReplicaMean(self): + """Ensures that cross_replica_mean() executes only when num_shards > 1.""" + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(4): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertNotEqual(mean, tensor) + + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(1): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertEqual(mean, tensor) + + with ops.Graph().as_default(): + with self.assertRaises(ValueError): # Outside of TPU context. + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + + def testBatchExecute(self): + """Ensure batch_execute runs in a round-robin fashion.""" + + def increment_var(var): + return lambda: var.assign_add(1) + + with ops.Graph().as_default(), self.test_session() as sess: + i = variable_scope.get_variable('i', initializer=0) + accumulators = [ + variable_scope.get_variable('var%d' % j, initializer=0) + for j in range(3) + ] + thunks = [increment_var(var) for var in accumulators] + increment_accumulators = utils.batch_execute(i, thunks, 2) + increment_i = i.assign_add(1) + + sess.run(variables.global_variables_initializer()) + + # Ensure one op per thunk. + self.assertEqual(3, len(increment_accumulators)) + + # Ensure round-robin execution. + values = [] + for _ in range(5): + sess.run(increment_accumulators) + sess.run(increment_i) + values.append(sess.run(accumulators)) + self.assertAllClose( + [ + [1, 1, 0], # + [2, 1, 1], # + [2, 2, 2], # + [3, 3, 2], # + [4, 3, 3] + ], + values) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index b2272a4cee09b35ff672514077b4b128b870b772..ee6549b109399766579b6ea18a987ae2c8275983 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -38,6 +38,7 @@ py_library( ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:special_math_ops", @@ -64,6 +65,7 @@ py_library( srcs = ["loss_functions.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/python:array_ops", "//tensorflow/python:math_ops", "//tensorflow/python:tensor_shape", @@ -195,7 +197,9 @@ py_library( srcs = ["utils.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/tpu", "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index 27ff951f16112e09b82ac6885072d966de09983f..a7b1f9d35c931fc44408be804479e758f28f7110 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -20,7 +20,6 @@ from __future__ import print_function import contextlib import itertools -import math import numpy as np @@ -67,7 +66,21 @@ class _DeviceContextGenerator(object): class FisherEstimator(object): - """Fisher estimator class supporting various approximations of the Fisher.""" + """Fisher estimator class supporting various approximations of the Fisher. + + Attributes: + cov_update_thunks: list of no-arg functions. Executing a function adds + covariance update ops for a single FisherFactor to the graph. + cov_update_ops: List of Ops. Running an op updates covariance matrices for a + single FisherFactor. + cov_update_op: Op. Running updates covariance matrices for all + FisherFactors. + inv_update_thunks: list of no-arg functions. Executing a function adds + inverse update ops for a single FisherFactor to the graph. + inv_update_ops: List of Ops. Running an op updates inverse matrices for a + single FisherFactor. + inv_update_op: Op. Running updates inverse matrices for all FisherFactors. + """ def __init__(self, variables, @@ -75,7 +88,7 @@ class FisherEstimator(object): damping, layer_collection, estimation_mode="gradients", - colocate_gradients_with_ops=False, + colocate_gradients_with_ops=True, cov_devices=None, inv_devices=None): """Create a FisherEstimator object. @@ -111,7 +124,7 @@ class FisherEstimator(object): is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking. colocate_gradients_with_ops: Whether we should request gradients be - colocated with their respective ops. + colocated with their respective ops. (Default: True) cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified. @@ -123,12 +136,13 @@ class FisherEstimator(object): ValueError: If no losses have been registered with layer_collection. """ + self._cov_ema_decay = cov_ema_decay self._variables = variables self._damping = damping self._estimation_mode = estimation_mode self._layers = layer_collection self._layers.create_subgraph() - self._check_registration(variables) + self._layers.check_registration(variables) self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, @@ -136,13 +150,31 @@ class FisherEstimator(object): "exact": self._get_grads_lists_exact } self._colocate_gradients_with_ops = colocate_gradients_with_ops + + # TODO(b/70674513): Factor device placement outside of this class. self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) if inv_devices == cov_devices: self._inv_device_context_generator = self._cov_device_context_generator else: self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) - setup = self._setup(cov_ema_decay) - self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup + + self._instantiate_factors() + + self.cov_update_thunks = [ + self._create_cov_update_thunk(factor) + for factor in self._layers.get_factors() + ] + self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks] + self.cov_update_op = control_flow_ops.group( + self.cov_update_ops, name="cov_update_op") + + self.inv_update_thunks = [ + self._create_inv_update_thunk(factor) + for factor in self._layers.get_factors() + ] + self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks] + self.inv_update_op = control_flow_ops.group( + self.inv_update_ops, name="inv_update_op") @property def variables(self): @@ -203,61 +235,8 @@ class FisherEstimator(object): return self._apply_transformation(vecs_and_vars, lambda fb, vec: fb.multiply(vec)) - def _check_registration(self, variables): - """Checks that all variable uses have been registered properly. - - Args: - variables: List of variables. - - Raises: - ValueError: If any registered variables are not included in the list. - ValueError: If any variable in the list is not registered. - ValueError: If any variable in the list is registered with the wrong - number of "uses" in the subgraph recorded (vs the number of times that - variable is actually used in the subgraph). - """ - # Note that overlapping parameters (i.e. those that share variables) will - # be caught by layer_collection.LayerParametersDict during registration. - - reg_use_map = self._layers.get_use_count_map() - - error_messages = [] - - for var in variables: - total_uses = self._layers.subgraph.variable_uses(var) - reg_uses = reg_use_map[var] - - if reg_uses == 0: - error_messages.append("Variable {} not registered.".format(var)) - elif (not math.isinf(reg_uses)) and reg_uses != total_uses: - error_messages.append( - "Variable {} registered with wrong number of uses ({} " - "registrations vs {} uses).".format(var, reg_uses, total_uses)) - - num_get_vars = len(reg_use_map) - - if num_get_vars > len(variables): - error_messages.append("{} registered variables were not included in list." - .format(num_get_vars - len(variables))) - - if error_messages: - error_messages = [ - "Found the following errors with variable registration:" - ] + error_messages - raise ValueError("\n\t".join(error_messages)) - - def _setup(self, cov_ema_decay): - """Sets up the various operations. - - Args: - cov_ema_decay: The decay factor used when calculating the covariance - estimate moving averages. - - Returns: - A triple (covs_update_op, invs_update_op, inv_updates_dict), where - covs_update_op is the grouped Op to update all the covariance estimates, - invs_update_op is the grouped Op to update all the inverses, and - inv_updates_dict is a dict mapping Op names to individual inverse updates. + def _instantiate_factors(self): + """Instantiates FisherFactors' variables. Raises: ValueError: If estimation_mode was improperly specified at construction. @@ -282,20 +261,25 @@ class FisherEstimator(object): with self._cov_device_context_generator(): fb.instantiate_factors(grads_list, self.damping) - cov_updates = [ - factor.make_covariance_update_op(cov_ema_decay) - for factor in self._layers.get_factors() - ] - inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()} + def _create_cov_update_thunk(self, factor): + """Constructs a covariance update thunk for a single FisherFactor.""" + + def thunk(): + with tf_ops.name_scope( + "create_cov_update_thunk", values=[self._cov_ema_decay]): + return factor.make_covariance_update_op(self._cov_ema_decay) + + return thunk - return control_flow_ops.group(*cov_updates), control_flow_ops.group( - *inv_updates.values()), inv_updates + def _create_inv_update_thunk(self, factor): + """Constructs an inverse update thunk for a single FisherFactor.""" - def _get_all_inverse_update_ops(self): - for factor in self._layers.get_factors(): - with self._inv_device_context_generator(): - for op in factor.make_inverse_update_ops(): - yield op + def thunk(): + with tf_ops.name_scope("create_inv_update_thunk"): + with self._inv_device_context_generator(): + return control_flow_ops.group(factor.make_inverse_update_ops()) + + return thunk def _get_grads_lists_gradients(self, tensors): grads_flat = gradients_impl.gradients( @@ -333,11 +317,7 @@ class FisherEstimator(object): return tuple((grad,) for grad in grads_all) def _get_grads_lists_exact(self, tensors): - """Returns a list of all gradients, computing them exactly. - - Args: - tensors: Tensors for which to compute gradients. - """ + """No docstring required.""" # Loop over all coordinates of all losses. grads_all = [] for loss in self._layers.losses: diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index e822a1213a4132522be8031401609c78572cb1a6..9436caf9618bc3d3c0dd7b3842420016b119464f 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -38,6 +38,7 @@ from __future__ import division from __future__ import print_function import abc +import enum # pylint: disable=g-bad-import-order import six @@ -52,14 +53,61 @@ from tensorflow.python.ops import math_ops # damping /= num_replications ** NORMALIZE_DAMPING_POWER NORMALIZE_DAMPING_POWER = 1.0 +# Methods for adjusting damping for FisherBlocks. See +# compute_pi_adjusted_damping() for details. +PI_OFF_NAME = "off" +PI_TRACENORM_NAME = "tracenorm" +PI_TYPE = PI_TRACENORM_NAME -def set_global_constants(normalize_damping_power=None): + +def set_global_constants(normalize_damping_power=None, pi_type=None): """Sets various global constants used by the classes in this module.""" global NORMALIZE_DAMPING_POWER + global PI_TYPE if normalize_damping_power is not None: NORMALIZE_DAMPING_POWER = normalize_damping_power + if pi_type is not None: + PI_TYPE = pi_type + + +def normalize_damping(damping, num_replications): + """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.""" + if NORMALIZE_DAMPING_POWER: + return damping / (num_replications ** NORMALIZE_DAMPING_POWER) + return damping + + +def compute_pi_tracenorm(left_cov, right_cov): + """Computes the scalar constant pi for Tikhonov regularization/damping. + + pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) + See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. + + Args: + left_cov: The left Kronecker factor "covariance". + right_cov: The right Kronecker factor "covariance". + + Returns: + The computed scalar constant pi for these Kronecker Factors (as a Tensor). + """ + # Instead of dividing by the dim of the norm, we multiply by the dim of the + # other norm. This works out the same in the ratio. + left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0] + right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0] + return math_ops.sqrt(left_norm / right_norm) + + +def compute_pi_adjusted_damping(left_cov, right_cov, damping): + + if PI_TYPE == PI_TRACENORM_NAME: + pi = compute_pi_tracenorm(left_cov, right_cov) + return (damping * pi, damping / pi) + + elif PI_TYPE == PI_OFF_NAME: + return (damping, damping) + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): @@ -153,7 +201,7 @@ class FullFB(FisherBlock): self._factor.register_damped_inverse(damping) def multiply_inverse(self, vector): - inverse = self._factor.get_inverse(self._damping) + inverse = self._factor.get_damped_inverse(self._damping) out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) return utils.column_to_tensors(vector, out_flat) @@ -409,10 +457,7 @@ class ConvDiagonalFB(FisherBlock): self._num_locations = ( inputs_shape[1] * inputs_shape[2] // (self._strides[1] * self._strides[2])) - - if NORMALIZE_DAMPING_POWER: - damping /= self._num_locations ** NORMALIZE_DAMPING_POWER - self._damping = damping + self._damping = normalize_damping(damping, self._num_locations) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvDiagonalFactor, @@ -465,11 +510,10 @@ class KroneckerProductFB(FisherBlock): Args: damping: The base damping factor (float or Tensor) for the damped inverse. """ - pi = utils.compute_pi(self._input_factor.get_cov(), - self._output_factor.get_cov()) - - self._input_damping = (damping**0.5) * pi - self._output_damping = (damping**0.5) / pi + self._input_damping, self._output_damping = compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) self._input_factor.register_damped_inverse(self._input_damping) self._output_factor.register_damped_inverse(self._output_damping) @@ -487,8 +531,9 @@ class KroneckerProductFB(FisherBlock): return 1.0 def multiply_inverse(self, vector): - left_factor_inv = self._input_factor.get_inverse(self._input_damping) - right_factor_inv = self._output_factor.get_inverse(self._output_damping) + left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) + right_factor_inv = self._output_factor.get_damped_inverse( + self._output_damping) reshaped_vector = utils.layer_params_to_mat2d(vector) reshaped_out = math_ops.matmul(left_factor_inv, math_ops.matmul(reshaped_vector, @@ -650,8 +695,8 @@ class ConvKFCBasicFB(KroneckerProductFB): grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) # Infer number of locations upon which convolution is applied. - self._num_locations = _num_conv_locations(inputs.shape.as_list(), - self._strides) + self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._strides) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, @@ -660,11 +705,9 @@ class ConvKFCBasicFB(KroneckerProductFB): self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - if NORMALIZE_DAMPING_POWER: - damping /= self._num_locations**NORMALIZE_DAMPING_POWER - self._damping = damping - + damping = normalize_damping(damping, self._num_locations) self._register_damped_input_and_output_inverses(damping) + self._damping = damping @property def _renorm_coeff(self): @@ -717,6 +760,267 @@ def _concat_along_batch_dim(tensor_list): return array_ops.concat(tensor_list, axis=0) -def _num_conv_locations(input_shape, strides): - """Returns the number of locations a Conv kernel is applied to.""" +def num_conv_locations(input_shape, strides): + """Returns the number of spatial locations a 2D Conv kernel is applied to. + + Args: + input_shape: list representing shape of inputs to the Conv layer. + strides: list representing strides for the Conv kernel. + + Returns: + A scalar |T| denoting the number of spatial locations for the Conv layer. + """ return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) + + +class FullyConnectedMultiIndepFB(KroneckerProductFB): + """FisherBlock for fully-connected layers that share parameters. + """ + + def __init__(self, layer_collection, inputs, outputs, has_bias=False): + """Creates a FullyConnectedMultiIndepFB block. + + Args: + layer_collection: LayerCollection instance. + inputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + inputs_size]. + outputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + outputs_size]. + has_bias: bool. If True, estimates Fisher with respect to a bias + parameter as well as the layer's parameters. + """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_uses = len(inputs) + + super(FullyConnectedMultiIndepFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. + return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, + ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + damping = normalize_damping(damping, self._num_uses) + self._register_damped_input_and_output_inverses(damping) + + @property + def _renorm_coeff(self): + return self._num_uses + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) + + +class SeriesFBApproximation(enum.IntEnum): + """See FullyConnectedSeriesFB.__init__ for description and usage.""" + option1 = 1 + option2 = 2 + + +class FullyConnectedSeriesFB(FisherBlock): + """FisherBlock for fully-connected layers that share parameters across time. + + See the following preprint for details: + https://openreview.net/pdf?id=HyMTkQZAb + + See the end of the appendix of the paper for a pseudo-code of the + algorithm being implemented by multiply_inverse here. Note that we are + using pre-computed versions of certain matrix-matrix products to speed + things up. This is explicitly explained wherever it is done. + """ + + def __init__(self, + layer_collection, + inputs, + outputs, + has_bias=False, + option=SeriesFBApproximation.option2): + """Constructs a new `FullyConnectedSeriesFB`. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + inputs: List of tensors of shape [batch_size, input_size]. + Inputs to the layer. + outputs: List of tensors of shape [batch_size, input_size]. + Outputs of the layer (before activations). + has_bias: Whether the layer includes a bias parameter. + option: A `SeriesFBApproximation` specifying the simplifying assumption + to be used in this block. `option1` approximates the cross-covariance + over time as a symmetric matrix, while `option2` makes + the assumption that training sequences are infinitely long. See section + 3.5 of the paper for more details. + """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_timesteps = len(inputs) + self._option = option + + super(FullyConnectedSeriesFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. + return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + damping = normalize_damping(damping, self._num_timesteps) + self._damping_input, self._damping_output = compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) + + if self._option == SeriesFBApproximation.option1: + self._input_factor.register_option1quants(self._damping_input) + self._output_factor.register_option1quants(self._damping_output) + elif self._option == SeriesFBApproximation.option2: + self._input_factor.register_option2quants(self._damping_input) + self._output_factor.register_option2quants(self._damping_output) + else: + raise ValueError( + "Unrecognized FullyConnectedSeriesFB approximation: {}".format( + self._option)) + + def multiply_inverse(self, vector): + # pylint: disable=invalid-name + + Z = utils.layer_params_to_mat2d(vector) + + # Derivations were done for "batch_dim==1" case so we need to convert to + # that orientation: + Z = array_ops.transpose(Z) + + if self._option == SeriesFBApproximation.option1: + + # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G. + L_A, psi_A = self._input_factor.get_option1quants(self._damping_input) + L_G, psi_G = self._output_factor.get_option1quants(self._damping_output) + + def gamma(x): + # We are assuming that each case has the same number of time-steps. + # If this stops being the case one shouldn't simply replace this T + # with its average value. Instead, one needs to go back to the + # definition of the gamma function from the paper. + T = self._num_timesteps + return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) + + # Y = gamma( psi_G*psi_A^T ) (computed element-wise) + # Even though Y is Z-independent we are recomputing it from the psi's + # each since Y depends on both A and G quantities, and it is relatively + # cheap to compute. + Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) + + # Z = L_G^T * Z * L_A + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = U_G^T * Z * U_A + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) + + # Z = Z .* Y + Z *= Y + + # Z = L_G * Z * L_A^T + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = U_G * Z * U_A^T + # Z = G0^(-1/2) * Z * A0^(-1/2) + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) + + elif self._option == SeriesFBApproximation.option2: + + # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1), + # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G. + P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input) + P_G, K_G, mu_G = self._output_factor.get_option2quants( + self._damping_output) + + # Our approach differs superficially from the pseudo-code in the paper + # in order to reduce the total number of matrix-matrix multiplies. + # In particular, the first three computations in the pseudo code are + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = Z - hPsi_G^T * Z * hPsi_A + # Z = E_G^T * Z * E_A + # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that + # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2) + # the entire computation can be written as + # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A + # = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A + # = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A + # - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A + # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A + # This final expression is computed by the following two lines: + # Z = Z - P_G * Z * P_A^T + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) + # Z = K_G^T * Z * K_A + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) + + # Z = Z ./ (1*1^T - mu_G*mu_A^T) + # Be careful with the outer product. We don't want to accidentally + # make it an inner-product instead. + tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A + # Prevent some numerical issues by setting any 0.0 eigs to 1.0 + tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype) + Z /= tmp + + # We now perform the transpose/reverse version of the operations + # derived above, whose derivation from the original pseudo-code is + # analgous. + # Z = K_G * Z * K_A^T + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) + + # Z = Z - P_G^T * Z * P_A + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) + + # Z = normalize (1/E[T]) * Z + # Note that this normalization is done because we compute the statistics + # by averaging, not summing, over time. (And the gradient is presumably + # summed over time, not averaged, and thus their scales are different.) + Z /= math_ops.cast(self._num_timesteps, Z.dtype) + + # Convert back to the "batch_dim==0" orientation. + Z = array_ops.transpose(Z) + + return utils.mat2d_to_layer_params(vector, Z) + + # pylint: enable=invalid-name + + def multiply(self, vector): + raise NotImplementedError + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py index 59389f8d385c18f50914d690cfaa2825ef807ed3..ac396309206fe09af65c2b70840a513fb25b579b 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -33,6 +33,10 @@ _allowed_symbols = [ 'ConvKFCBasicFB', 'ConvDiagonalFB', 'set_global_constants', + 'compute_pi_tracenorm', + 'compute_pi_adjusted_damping', + 'num_conv_locations', + 'normalize_damping' ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index fbc192f1dcfa0b384e2cb31c43af3651436321ea..f59168cbc05fffd104ff5a44308eefd206beb9db 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -27,6 +27,8 @@ import six from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import special_math_ops @@ -50,11 +52,15 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 # matrix powers. Must be nonnegative. EIGENVALUE_CLIPPING_THRESHOLD = 0.0 +# Colocate the covariance ops and variables with the input tensors for each +# factor. +COLOCATE_COV_OPS_WITH_INPUTS = True + @contextlib.contextmanager -def _maybe_colocate_with(op, colocate_cov_ops_with_inputs): - """Context to colocate with `op` if `colocate_cov_ops_with_inputs`.""" - if colocate_cov_ops_with_inputs: +def maybe_colocate_with(op): + """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`.""" + if COLOCATE_COV_OPS_WITH_INPUTS: if isinstance(op, (list, tuple)): with tf_ops.colocate_with(op[0]): yield @@ -68,12 +74,14 @@ def _maybe_colocate_with(op, colocate_cov_ops_with_inputs): def set_global_constants(init_covariances_at_zero=None, zero_debias=None, eigenvalue_decomposition_threshold=None, - eigenvalue_clipping_threshold=None): + eigenvalue_clipping_threshold=None, + colocate_cov_ops_with_inputs=None): """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD + global COLOCATE_COV_OPS_WITH_INPUTS if init_covariances_at_zero is not None: INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero @@ -83,6 +91,8 @@ def set_global_constants(init_covariances_at_zero=None, EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold + if colocate_cov_ops_with_inputs is not None: + COLOCATE_COV_OPS_WITH_INPUTS = colocate_cov_ops_with_inputs def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument @@ -101,7 +111,7 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) -def _compute_cov(tensor, normalizer=None): +def compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. This function is meant to be applied to random matrices for which the true row @@ -109,6 +119,8 @@ def _compute_cov(tensor, normalizer=None): Args: tensor: A 2D Tensor. + tensor_right: An optional 2D Tensor. If provided, this function computes + the matrix product tensor^T * tensor_right instead of tensor^T * tensor. normalizer: optional scalar for the estimator (by default, the normalizer is the number of rows of tensor). @@ -117,12 +129,17 @@ def _compute_cov(tensor, normalizer=None): """ if normalizer is None: normalizer = array_ops.shape(tensor)[0] - cov = (math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( - normalizer, tensor.dtype)) - return (cov + array_ops.transpose(cov)) / math_ops.cast(2, cov.dtype) + if tensor_right is None: + cov = ( + math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( + normalizer, tensor.dtype)) + return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype) + else: + return (math_ops.matmul(tensor, tensor_right, transpose_a=True) / + math_ops.cast(normalizer, tensor.dtype)) -def _append_homog(tensor): +def append_homog(tensor): """Appends a homogeneous coordinate to the last dimension of a Tensor. Args: @@ -135,7 +152,7 @@ def _append_homog(tensor): rank = len(tensor.shape.as_list()) shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) ones = array_ops.ones(shape, dtype=tensor.dtype) - return array_ops.concat([tensor, ones], axis=rank-1) + return array_ops.concat([tensor, ones], axis=rank - 1) def scope_string_from_params(params): @@ -173,8 +190,8 @@ def scope_string_from_params(params): elif isinstance(param, (tf_ops.Tensor, variables.Variable)): name_parts.append(scope_string_from_name(param)) else: - raise ValueError( - "Encountered an unsupported param type {}".format(type(param))) + raise ValueError("Encountered an unsupported param type {}".format( + type(param))) return "_".join(name_parts) @@ -225,6 +242,10 @@ class FisherFactor(object): """ pass + @abc.abstractproperty + def _dtype(self): + pass + @property def _cov_initializer(self): return covariance_initializer @@ -236,7 +257,8 @@ class FisherFactor(object): "cov", initializer=self._cov_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) @abc.abstractmethod def _compute_new_cov(self, idx=0): @@ -250,15 +272,27 @@ class FisherFactor(object): Returns: An Op for updating the covariance Variable referenced by _cov. """ - new_cov = math_ops.add_n( - tuple(self._compute_new_cov(idx) for idx in range(self._num_sources))) - - return moving_averages.assign_moving_average( - self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + new_cov_contribs = tuple(self._compute_new_cov(idx) + for idx in range(self._num_sources)) + # This gets the job done but we might want a better solution in the future. + # In particular, we could have a separate way of specifying where the + # the cov variables finally end up, independent of where their various + # contributions are computed. Right now these are the same thing, but in + # the future we might want to perform the cov computations on each tower, + # so that each tower will be considered a "source" (allowing us to reuse + # the existing "source" code for this). + with maybe_colocate_with(new_cov_contribs[0]): + new_cov = math_ops.add_n(new_cov_contribs) + # Synchronize value across all TPU cores. + if utils.on_tpu(): + new_cov = utils.cross_replica_mean(new_cov) + return moving_averages.assign_moving_average( + self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + @abc.abstractmethod def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - return [] + pass def get_cov(self): return self._cov @@ -273,6 +307,13 @@ class InverseProvidingFactor(FisherFactor): _cov_shape properties. """ + # TODO(b/69108481): This class (and its subclasses) should be refactored to + # serve the matrix quantities it computes as both (potentially stale) + # variables, updated by the inverse update ops, and fresh values stored in + # tensors that recomputed once every session.run() call. Currently matpower + # and damp_inverse have the former behavior, while eigendecomposition has + # the latter. + def __init__(self): self._inverses_by_damping = {} self._matpower_by_exp_and_damping = {} @@ -283,6 +324,10 @@ class InverseProvidingFactor(FisherFactor): def register_damped_inverse(self, damping): """Registers a damped inverse needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_inverse. + Args: damping: The damping value (float or Tensor) for this factor. """ @@ -293,12 +338,17 @@ class InverseProvidingFactor(FisherFactor): "inv_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._inverses_by_damping[damping] = inv def register_matpower(self, exp, damping): """Registers a matrix power needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_matpower. + Args: exp: The exponent (float or Tensor) to raise the matrix to. damping: The damping value (float or Tensor). @@ -311,59 +361,81 @@ class InverseProvidingFactor(FisherFactor): "matpower_exp{}_damp{}".format(exp_string, damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._matpower_by_exp_and_damping[(exp, damping)] = matpower - def register_eigendecomp(self): - """Registers that an eigendecomposition is needed by a FisherBlock.""" - if not self._eigendecomp: - self._eigendecomp = linalg_ops.self_adjoint_eig(self._cov) - def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - ops = super(InverseProvidingFactor, self).make_inverse_update_ops() + ops = [] + + # We do this to ensure that we don't reuse the eigendecomp from old calls + # to make_inverse_update_ops that may be placed on different devices. This + # can happen is the user has both a permanent and lazily constructed + # version of the inverse ops (and only uses one of them). + self.reset_eigendecomp() num_inverses = len(self._inverses_by_damping) matrix_power_registered = bool(self._matpower_by_exp_and_damping) - use_eig = (self._eigendecomp or matrix_power_registered or - num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + use_eig = ( + self._eigendecomp or matrix_power_registered or + num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) if use_eig: - self.register_eigendecomp() # ensures self._eigendecomp is set - eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence - - # The matrix self._cov is positive semidefinite by construction, but the - # numerical eigenvalues could be negative due to numerical errors, so here - # we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD. - clipped_eigenvalues = math_ops.maximum(eigenvalues, - EIGENVALUE_CLIPPING_THRESHOLD) + eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence for damping, inv in self._inverses_by_damping.items(): ops.append( inv.assign( - math_ops.matmul(eigenvectors / (clipped_eigenvalues + damping), + math_ops.matmul(eigenvectors / (eigenvalues + damping), array_ops.transpose(eigenvectors)))) for (exp, damping), matpower in self._matpower_by_exp_and_damping.items(): ops.append( matpower.assign( - math_ops.matmul(eigenvectors * (clipped_eigenvalues + damping)** - exp, array_ops.transpose(eigenvectors)))) + math_ops.matmul(eigenvectors * + (eigenvalues + damping)**exp, + array_ops.transpose(eigenvectors)))) + # These ops share computation and should be run on a single device. + ops = [control_flow_ops.group(*ops)] else: for damping, inv in self._inverses_by_damping.items(): ops.append(inv.assign(utils.posdef_inv(self._cov, damping))) return ops - def get_inverse(self, damping): + def get_damped_inverse(self, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._inverses_by_damping[damping] def get_matpower(self, exp, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._matpower_by_exp_and_damping[(exp, damping)] def get_eigendecomp(self): + """Creates or retrieves eigendecomposition of self._cov.""" + # Unlike get_inverse and get_matpower this doesn't retrieve a stored + # variable, but instead always computes a fresh version from the current + # value of get_cov(). + if not self._eigendecomp: + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least FLAGS.eigenvalue_clipping_threshold + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + self._eigendecomp = (clipped_eigenvalues, eigenvectors) + return self._eigendecomp + def reset_eigendecomp(self): + self._eigendecomp = None + class FullFactor(InverseProvidingFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. @@ -374,41 +446,38 @@ class FullFactor(InverseProvidingFactor): def __init__(self, params_grads, - batch_size, - colocate_cov_ops_with_inputs=False): + batch_size): self._batch_size = batch_size - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs - self._orig_params_grads_name = scope_string_from_params( - [params_grads, self._batch_size]) - params_grads_flat = [] - for params_grad in params_grads: - with _maybe_colocate_with(params_grad, - self._colocate_cov_ops_with_inputs): - col = utils.tensors_to_column(params_grad) - params_grads_flat.append(col) - self._params_grads_flat = tuple(params_grads_flat) + self._params_grads = tuple(utils.ensure_sequence(params_grad) + for params_grad in params_grads) super(FullFactor, self).__init__() @property def _var_scope(self): - return "ff_full/" + self._orig_params_grads_name + return "ff_full/" + scope_string_from_params( + [self._params_grads, self._batch_size]) @property def _cov_shape(self): - size = self._params_grads_flat[0].shape[0] - return [size, size] + size = sum(param_grad.shape.num_elements() + for param_grad in self._params_grads[0]) + return (size, size) @property def _num_sources(self): - return len(self._params_grads_flat) + return len(self._params_grads) + + @property + def _dtype(self): + return self._params_grads[0][0].dtype def _compute_new_cov(self, idx=0): # This will be a very basic rank 1 estimate - with _maybe_colocate_with(self._params_grads_flat[idx], - self._colocate_cov_ops_with_inputs): - return ((self._params_grads_flat[idx] * array_ops.transpose( - self._params_grads_flat[idx])) / math_ops.cast( - self._batch_size, self._params_grads_flat[idx].dtype)) + with maybe_colocate_with(self._params_grads[idx]): + params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) + return ((params_grads_flat * array_ops.transpose( + params_grads_flat)) / math_ops.cast(self._batch_size, + params_grads_flat.dtype)) class DiagonalFactor(FisherFactor): @@ -421,6 +490,9 @@ class DiagonalFactor(FisherFactor): def _cov_initializer(self): return diagonal_covariance_initializer + def make_inverse_update_ops(self): + return [] + class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. @@ -431,38 +503,36 @@ class NaiveDiagonalFactor(DiagonalFactor): def __init__(self, params_grads, - batch_size, - colocate_cov_ops_with_inputs=False): + batch_size): + self._params_grads = tuple(utils.ensure_sequence(params_grad) + for params_grad in params_grads) self._batch_size = batch_size - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs - params_grads_flat = [] - for params_grad in params_grads: - with _maybe_colocate_with(params_grad, - self._colocate_cov_ops_with_inputs): - col = utils.tensors_to_column(params_grad) - params_grads_flat.append(col) - self._params_grads = tuple(params_grads_flat) - self._orig_params_grads_name = scope_string_from_params( - [self._params_grads, self._batch_size]) super(NaiveDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_naivediag/" + self._orig_params_grads_name + return "ff_naivediag/" + scope_string_from_params( + [self._params_grads, self._batch_size]) @property def _cov_shape(self): - return self._params_grads[0].shape + size = sum(param_grad.shape.num_elements() + for param_grad in self._params_grads[0]) + return (size, 1) @property def _num_sources(self): return len(self._params_grads) + @property + def _dtype(self): + return self._params_grads[0][0].dtype + def _compute_new_cov(self, idx=0): - with _maybe_colocate_with(self._params_grads[idx], - self._colocate_cov_ops_with_inputs): - return (math_ops.square(self._params_grads[idx]) / math_ops.cast( - self._batch_size, self._params_grads[idx].dtype)) + with maybe_colocate_with(self._params_grads[idx]): + params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) + return (math_ops.square(params_grads_flat) / math_ops.cast( + self._batch_size, params_grads_flat.dtype)) class FullyConnectedDiagonalFactor(DiagonalFactor): @@ -471,18 +541,15 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): Given in = [batch_size, input_size] and out_grad = [batch_size, output_size], approximates the covariance as, - Cov(in, out) = (1/batch_size) \sum_{i} outer(in[i], out_grad[i]) ** 2.0 + Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0 where the square is taken element-wise. """ - # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, - has_bias=False, - colocate_cov_ops_with_inputs=False): + has_bias=False): """Instantiate FullyConnectedDiagonalFactor. Args: @@ -491,44 +558,46 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): outputs_grads: List of Tensors of shape [batch_size, output_size]. Gradient of loss with respect to layer's preactivations. has_bias: bool. If True, append '1' to each input. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. """ + self._inputs = inputs + self._has_bias = has_bias self._outputs_grads = outputs_grads - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs self._batch_size = array_ops.shape(inputs)[0] - self._orig_tensors_name = scope_string_from_params((inputs,) + - tuple(outputs_grads)) - - # Note that we precompute the required operations on the inputs since the - # inputs don't change with the 'idx' argument to _compute_new_cov. (Only - # the target entry of _outputs_grads changes with idx.) - with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): - if has_bias: - inputs = _append_homog(inputs) - self._squared_inputs = math_ops.square(inputs) + self._squared_inputs = None super(FullyConnectedDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_diagfc/" + self._orig_tensors_name + return "ff_diagfc/" + scope_string_from_params( + (self._inputs,) + tuple(self._outputs_grads)) @property def _cov_shape(self): - return [self._squared_inputs.shape[1], self._outputs_grads[0].shape[1]] + return [self._inputs.shape[1] + self._has_bias, + self._outputs_grads[0].shape[1]] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. - with _maybe_colocate_with(self._squared_inputs, - self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._outputs_grads[idx]): + # We only need to compute squared_inputs once + if self._squared_inputs is None: + inputs = self._inputs + if self._has_bias: + inputs = append_homog(self._inputs) + self._squared_inputs = math_ops.square(inputs) + new_cov = math_ops.matmul( self._squared_inputs, math_ops.square(self._outputs_grads[idx]), @@ -540,16 +609,13 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): class ConvDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" - # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, filter_shape, strides, padding, - has_bias=False, - colocate_cov_ops_with_inputs=False): + has_bias=False): """Creates a ConvDiagonalFactor object. Args: @@ -564,53 +630,64 @@ class ConvDiagonalFactor(DiagonalFactor): padding: The padding in this layer (1-D of Tensor length 4). has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. """ + self._inputs = inputs self._filter_shape = filter_shape + self._strides = strides + self._padding = padding self._has_bias = has_bias self._outputs_grads = outputs_grads - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs - - self._orig_tensors_name = scope_string_from_name((inputs,) - + tuple(outputs_grads)) - - # Note that we precompute the required operations on the inputs since the - # inputs don't change with the 'idx' argument to _compute_new_cov. (Only - # the target entry of _outputs_grads changes with idx.) - with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): - filter_height, filter_width, _, _ = self._filter_shape - patches = array_ops.extract_image_patches( - inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=strides, - rates=[1, 1, 1, 1], - padding=padding) - - if has_bias: - patches = _append_homog(patches) - - self._patches = patches + self._patches = None super(ConvDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_convdiag/" + self._orig_tensors_name + return "ff_convdiag/" + scope_string_from_name( + (self._inputs,) + tuple(self._outputs_grads)) @property def _cov_shape(self): filter_height, filter_width, in_channels, out_channels = self._filter_shape - return [filter_height * filter_width * in_channels + self._has_bias, - out_channels] + return [ + filter_height * filter_width * in_channels + self._has_bias, + out_channels + ] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + + def make_covariance_update_op(self, ema_decay): + with maybe_colocate_with(self._inputs): + filter_height, filter_width, _, _ = self._filter_shape + + # TODO(b/64144716): there is potential here for a big savings in terms + # of memory use. + patches = array_ops.extract_image_patches( + self._inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=[1, 1, 1, 1], + padding=self._padding) + + if self._has_bias: + patches = append_homog(patches) + + self._patches = patches + + op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) + + self._patches = None + + return op + def _compute_new_cov(self, idx=0): - with _maybe_colocate_with(self._outputs_grads[idx], - self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._outputs_grads[idx]): outputs_grad = self._outputs_grads[idx] batch_size = array_ops.shape(self._patches)[0] @@ -634,23 +711,18 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): def __init__(self, tensors, - has_bias=False, - colocate_cov_ops_with_inputs=False): + has_bias=False): """Instantiate FullyConnectedKroneckerFactor. Args: tensors: List of Tensors of shape [batch_size, n]. Represents either a layer's inputs or its output's gradients. - has_bias: bool. If True, assume this factor is for the layer's inputs and - append '1' to each row. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. + has_bias: bool. If True, append '1' to each row. """ # The tensor argument is either a tensor of input activations or a tensor of # output pre-activation gradients. self._has_bias = has_bias self._tensors = tensors - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(FullyConnectedKroneckerFactor, self).__init__() @property @@ -667,13 +739,16 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._tensors) + @property + def _dtype(self): + return self._tensors[0].dtype + def _compute_new_cov(self, idx=0): - with _maybe_colocate_with(self._tensors[idx], - self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._tensors[idx]): tensor = self._tensors[idx] if self._has_bias: - tensor = _append_homog(tensor) - return _compute_cov(tensor) + tensor = append_homog(tensor) + return compute_cov(tensor) class ConvInputKroneckerFactor(InverseProvidingFactor): @@ -682,7 +757,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): Estimates E[ a a^T ] where a is the inputs to a convolutional layer given example x. Expectation is taken over all examples and locations. - Equivalent to \Omega in https://arxiv.org/abs/1602.01407 for details. See + Equivalent to Omega in https://arxiv.org/abs/1602.01407 for details. See Section 3.1 Estimating the factors. """ @@ -691,8 +766,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): filter_shape, strides, padding, - has_bias=False, - colocate_cov_ops_with_inputs=False): + has_bias=False): """Initializes ConvInputKroneckerFactor. Args: @@ -704,15 +778,12 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): width_stride, in_channel_stride]. padding: str. Padding method for layer. "SAME" or "VALID". has_bias: bool. If True, append 1 to in_channel. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. """ self._filter_shape = filter_shape self._strides = strides self._padding = padding self._has_bias = has_bias self._inputs = inputs - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvInputKroneckerFactor, self).__init__() @property @@ -732,13 +803,19 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return self._inputs.dtype + def _compute_new_cov(self, idx=0): if idx != 0: raise ValueError("ConvInputKroneckerFactor only supports idx = 0") - # TODO(jamesmartens): factor this patches stuff out into a utility function - with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._inputs): filter_height, filter_width, in_channels, _ = self._filter_shape + + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. patches = array_ops.extract_image_patches( self._inputs, ksizes=[1, filter_height, filter_width, 1], @@ -747,12 +824,24 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): padding=self._padding) flatten_size = (filter_height * filter_width * in_channels) + # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde + # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), + # where M = minibatch size, |T| = number of spatial locations, + # |Delta| = number of spatial offsets, and J = number of input maps + # for convolutional layer l. patches_flat = array_ops.reshape(patches, [-1, flatten_size]) - + # We append a homogenous coordinate to patches_flat if the layer has + # bias parameters. This gives us [[A_l]]_H from the paper. if self._has_bias: - patches_flat = _append_homog(patches_flat) - - return _compute_cov(patches_flat) + patches_flat = append_homog(patches_flat) + # We call compute_cov without passing in a normalizer. compute_cov uses + # the first dimension of patches_flat i.e. M|T| as the normalizer by + # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with + # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from + # the paper but has a different scale here for consistency with + # ConvOutputKroneckerFactor. + # (Tilde omitted over A for clarity.) + return compute_cov(patches_flat) class ConvOutputKroneckerFactor(InverseProvidingFactor): @@ -762,22 +851,19 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over all examples and locations. - Equivalent to \Gamma in https://arxiv.org/abs/1602.01407 for details. See + Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See Section 3.1 Estimating the factors. """ - def __init__(self, outputs_grads, colocate_cov_ops_with_inputs=False): + def __init__(self, outputs_grads): """Initializes ConvOutputKroneckerFactor. Args: outputs_grads: list of Tensors. Each Tensor is of shape [batch_size, height, width, out_channels]. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. """ self._out_channels = outputs_grads[0].shape.as_list()[3] self._outputs_grads = outputs_grads - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvOutputKroneckerFactor, self).__init__() @property @@ -793,9 +879,292 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): - with _maybe_colocate_with(self._outputs_grads[idx], - self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._outputs_grads[idx]): + # reshaped_tensor below is the matrix DS_l defined in the KFC paper + # (tilde omitted over S for clarity). It has shape M|T| x I, where + # M = minibatch size, |T| = number of spatial locations, and + # I = number of output maps for convolutional layer l. reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], [-1, self._out_channels]) - return _compute_cov(reshaped_tensor) + # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, + # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l + # as defined in the paper, with shape I x I. + # (Tilde omitted over S for clarity.) + return compute_cov(reshaped_tensor) + + +class FullyConnectedMultiKF(InverseProvidingFactor): + """Kronecker factor for a fully connected recurrent layer.""" + + def __init__(self, + tensor_lists, + has_bias=False): + """Constructs a new `FullyConnectedMultiKF`. + + Args: + tensor_lists: List of lists of Tensors of shape [batch_size, n]. + has_bias: bool. If True, '1' is appended to each row. + """ + + self._tensor_lists = tensor_lists + self._has_bias = has_bias + self._batch_size = array_ops.shape(tensor_lists[0][0])[0] + self._num_timesteps = len(tensor_lists[0]) + self._tensors = [None] * len(tensor_lists) + + self._cov_dt1 = None + self._option1quants_by_damping = {} + self._option2quants_by_damping = {} + + super(FullyConnectedMultiKF, self).__init__() + + @property + def _var_scope(self): + return "ff_fc_multi/" + scope_string_from_params(self._tensor_lists) + + @property + def _num_sources(self): + return len(self._tensor_lists) + + @property + def _dtype(self): + return self._tensor_lists[0][0].dtype + + def make_covariance_update_op(self, ema_decay): + + op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) + + if self._cov_dt1 is not None: + new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx) + for idx in range(self._num_sources)) + + with maybe_colocate_with(new_cov_dt1_contribs[0]): + new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs) + + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. + op = control_flow_ops.group(op, op2) + + return op + + def _compute_new_cov(self, idx=0): + with maybe_colocate_with(self._tensor_lists[idx]): + tensor = array_ops.concat(self._tensor_lists[idx], 0) + if self._has_bias: + tensor = append_homog(tensor) + # We save these so they can be used by _compute_new_cov_dt1 + self._tensors[idx] = tensor + return compute_cov(tensor) + + def _compute_new_cov_dt1(self, idx=0): + tensor = self._tensors[idx] + with maybe_colocate_with(tensor): + # Is there a more elegant way to do this computation? + tensor_present = tensor[:-self._batch_size, :] + tensor_future = tensor[self._batch_size:, :] + # We specify a normalizer for this computation to ensure a PSD Fisher + # block estimate. This is equivalent to padding with zeros, as was done + # in Section B.2 of the appendix. + normalizer = self._num_timesteps * self._batch_size + return compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=normalizer) + + @property + def _cov_shape(self): + size = self._tensor_lists[0][0].shape[1] + self._has_bias + return [size, size] + + @property + def _vec_shape(self): + size = self._tensor_lists[0][0].shape[1] + self._has_bias + return [size] + + def get_option1quants(self, damping): + return self._option1quants_by_damping[damping] + + def get_option2quants(self, damping): + return self._option2quants_by_damping[damping] + + def get_cov_dt1(self): + assert self._cov_dt1 is not None + return self._cov_dt1 + + def register_cov_dt1(self): + """Create a variable representing temporal cross-covariance. + + (This is technically the second moment, not covariance, since it's + not mean subtracted.) + """ + if self._cov_dt1 is None: + with variable_scope.variable_scope(self._var_scope): + self._cov_dt1 = variable_scope.get_variable( + "cov_dt1", + initializer=init_ops.zeros_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + + def register_option1quants(self, damping): + + self.register_cov_dt1() + + if damping not in self._option1quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Lmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + psi = variable_scope.get_variable( + "psi_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option1quants_by_damping[damping] = (Lmat, psi) + + def register_option2quants(self, damping): + + self.register_cov_dt1() + + if damping not in self._option2quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Pmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + Kmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Kmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + mu = variable_scope.get_variable( + "mu_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option2quants_by_damping[damping] = (Pmat, Kmat, mu) + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + # TODO(b/69918258): Add correctness tests for this method. + # pylint: disable=invalid-name + + ops = super(FullyConnectedMultiKF, self).make_inverse_update_ops() + + if (len(self._option1quants_by_damping) + + len(self._option2quants_by_damping)): + + # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from + # the pseudo-code in the original paper. Because the computations for + # the A and G case are essentially the same they can both be performed by + # the same class (this one). + + C1 = self.get_cov_dt1() + + # Get the eigendecomposition of C0 (= self.get_cov()) + eigen_e, eigen_V = self.get_eigendecomp() + + # TODO(b/69678661): Note, there is an implicit assumption here that C1 + # and C0 (as represented here by its eigen-decomp) are consistent. This + # could fail to be the case if self._cov and self._cov_dt1 are not updated + # consistently, or are somehow read between or during the cov updates. + # Can this possibly happen? Is there a way to prevent it? + + for damping, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # The following line imposses the symmetry assumed by "Option 1" on C1. + # Stangely the code can work okay with this line commented out, + # depending on how psd_eig is defined. I'm not sure why. + C1 = (C1 + array_ops.transpose(C1)) / 2.0 + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) + hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0) + + # Compute the decomposition U*diag(psi)*U^T = hPsi + psi, U = utils.posdef_eig(hPsi) + + # L = C0^(-1/2) * U + Lmat = math_ops.matmul(invsqrtC0, U) + + ops.append(Lmat_var.assign(Lmat)) + ops.append(psi_var.assign(psi)) + + for damping, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + # compute C0^(-1/2) + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # Compute the product C0^(-1/2) * C1 + invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1) + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) + hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0) + + # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi + # Note that we using the notation mu instead of "m" for the eigenvalues. + # Instead of computing the product hPsi^T * hPsi and then doing an + # eigen-decomposition of this we just compute the SVD of hPsi and then + # square the singular values to get the eigenvalues. For a justification + # of this approach, see: + # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition + sqrtmu, _, E = linalg_ops.svd(hPsi) + mu = math_ops.square(sqrtmu) + + # Mathematically, the eigenvalues should not should not exceed 1.0, but + # due to numerical issues, or possible issues with inconsistent + # values of C1 and (the eigen-decomposition of) C0 they might. So + # we enforce this condition. + mu = math_ops.minimum(mu, 1.0) + + # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) + Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) + + # K = C_0^(-1/2) * E + Kmat = math_ops.matmul(invsqrtC0, E) + + ops.append(Pmat_var.assign(Pmat)) + ops.append(Kmat_var.assign(Kmat)) + ops.append(mu_var.assign(mu)) + + return [control_flow_ops.group(*ops)] + + # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py index 23ee93cd405bbf719939df89d525c812ee061f8b..ad93919149c287b1932dd2b6bd772c0dab26192d 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -41,6 +41,9 @@ _allowed_symbols = [ "ConvOutputKroneckerFactor", "ConvDiagonalFactor", "set_global_constants", + "maybe_colocate_with", + "compute_cov", + "append_homog" ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 3a005ee39dd9400c21ae6c41fad5351d7fff2aac..8d450f04f379701e46a18b2e34bbbd6fcfcce2bb 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -26,7 +26,9 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +from functools import partial +import math import six from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb @@ -57,20 +59,22 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = { APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, } +APPROX_KRONECKER_INDEP_NAME = "kron_indep" +APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" +APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" + +_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, + APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, + option=1), + APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, + option=2) +} + # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" -# TODO(jamesmartens): need to add find_canonical_output back into this somewhere - - -def ensure_sequence(obj): - """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" - if isinstance(obj, (tuple, list)): - return obj - else: - return (obj,) - class LayerParametersDict(OrderedDict): """An OrderedDict where keys are Tensors or tuples of Tensors. @@ -130,7 +134,6 @@ class LayerCollection(object): def __init__(self, graph=None, - colocate_cov_ops_with_inputs=False, name="LayerCollection"): self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() @@ -142,7 +145,8 @@ class LayerCollection(object): self._default_generic_approximation = APPROX_FULL_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs + self._default_fully_connected_multi_approximation = ( + APPROX_KRONECKER_SERIES_2_NAME) with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @@ -152,19 +156,13 @@ class LayerCollection(object): """LossFunctions registered with this LayerCollection.""" return list(self._loss_dict.values()) - def is_variable_registered(self, variable): - """Checks whether the variable has already been registered. - - Args: - variable: A single variable or tensor. - Returns: - True if the variable has been registered either by itself or as part of a - tuple. - """ - return any([ - variable in key if isinstance(key, (tuple, list)) else variable == key - for key in self.fisher_blocks.keys() - ]) + @property + def registered_variables(self): + """A tuple of all of the variables currently registered.""" + tuple_of_tuples = (utils.ensure_sequence(key) for key, block + in six.iteritems(self.fisher_blocks)) + flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) + return flat_tuple @property def linked_parameters(self): @@ -213,6 +211,16 @@ class LayerCollection(object): value)) self._default_convolution_2d_approximation = value + @property + def default_fully_connected_multi_approximation(self): + return self._default_fully_connected_multi_approximation + + def set_default_fully_connected_multi_approximation(self, value): + if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("{} is not a valid approximation for a fully-connected " + "multi layer.".format(value)) + self._default_fully_connected_multi_approximation = value + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. @@ -221,7 +229,7 @@ class LayerCollection(object): existing registrations and to register if valid. fisher_block: The associated `FisherBlock`. reuse: Method to use for inserting new `FisherBlock`s. One of True, False, - or VARIABLE_SCOPE. + or 'VARIABLE_SCOPE'. Raises: ValueError: If `layer_key` was already registered and reuse is `False`, @@ -258,9 +266,9 @@ class LayerCollection(object): variable_to_block = { var: (params, block) for (params, block) in self.fisher_blocks.items() - for var in ensure_sequence(params) + for var in utils.ensure_sequence(params) } - for variable in ensure_sequence(layer_key): + for variable in utils.ensure_sequence(layer_key): if variable in variable_to_block: prev_key, prev_block = variable_to_block[variable] raise ValueError( @@ -272,13 +280,65 @@ class LayerCollection(object): def get_use_count_map(self): """Returns a dict of variables to their number of registrations.""" + # TODO(b/70283403): Reimplement this in the old way, where each + # registration function would be responsible for incrementing the count. + # Also, this version has a bug: it won't do the right thing for generic + # registration for parameters that are shared. i.e. it won't set the use + # count to infinity. vars_to_uses = defaultdict(int) for key, block in six.iteritems(self.fisher_blocks): - key = key if isinstance(key, (tuple, list)) else (key,) + n = ( + block.num_inputs()*block.num_registered_minibatches if isinstance( + block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB)) + else block.num_registered_minibatches) + key = utils.ensure_sequence(key) for k in key: - vars_to_uses[k] += block.num_registered_minibatches + vars_to_uses[k] += n return vars_to_uses + def check_registration(self, variables): + """Checks that all variable uses have been registered properly. + + Args: + variables: List of variables. + + Raises: + ValueError: If any registered variables are not included in the list. + ValueError: If any variable in the list is not registered. + ValueError: If any variable in the list is registered with the wrong + number of "uses" in the subgraph recorded (vs the number of times that + variable is actually used in the subgraph). + """ + # Note that overlapping parameters (i.e. those that share variables) will + # be caught by layer_collection.LayerParametersDict during registration. + + reg_use_map = self.get_use_count_map() + + error_messages = [] + + for var in variables: + total_uses = self.subgraph.variable_uses(var) + reg_uses = reg_use_map[var] + + if reg_uses == 0: + error_messages.append("Variable {} not registered.".format(var)) + elif (not math.isinf(reg_uses)) and reg_uses != total_uses: + error_messages.append( + "Variable {} registered with wrong number of uses ({} " + "registrations vs {} uses).".format(var, reg_uses, total_uses)) + + num_get_vars = len(reg_use_map) + + if num_get_vars > len(variables): + error_messages.append("{} registered variables were not included in list." + .format(num_get_vars - len(variables))) + + if error_messages: + error_messages = [ + "Found the following errors with variable registration:" + ] + error_messages + raise ValueError("\n\t".join(error_messages)) + def get_blocks(self): return self.fisher_blocks.values() @@ -312,12 +372,12 @@ class LayerCollection(object): ValueError: If the parameters were already registered in a layer or identified as part of an incompatible group. """ - params = frozenset(ensure_sequence(params)) + params = frozenset(utils.ensure_sequence(params)) # Check if any of the variables in 'params' is already in # 'self.fisher_blocks.keys()'. for registered_params, fisher_block in self.fisher_blocks.items(): - registered_params_set = set(ensure_sequence(registered_params)) + registered_params_set = set(utils.ensure_sequence(registered_params)) for variable in params: if (variable in registered_params_set and params != registered_params_set): @@ -351,7 +411,7 @@ class LayerCollection(object): def _get_linked_approx(self, params): """If params were linked, return their specified approximation.""" - params_set = frozenset(ensure_sequence(params)) + params_set = frozenset(utils.ensure_sequence(params)) if params_set in self.linked_parameters: return self.linked_parameters[params_set] else: @@ -370,11 +430,11 @@ class LayerCollection(object): this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. - outputs: Tensor of shape [batch_size, output_size]. Preactivations + outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -416,10 +476,10 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. - Preactivations produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + Output produced by layer. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -449,14 +509,11 @@ class LayerCollection(object): """Registers a generic layer. Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [kernel_height, - kernel_width, in_channels, out_channels]. Bias should have shape - [out_channels]. + params: Tensor or tuple of Tensors corresponding to the parameters. batch_size: 0-D Tensor. Size of the minibatch. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + approx: str. One of "full" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -477,6 +534,47 @@ class LayerCollection(object): block = self.register_block(params, block_type(self, params), reuse=reuse) block.register_additional_minibatch(batch_size) + def register_fully_connected_multi(self, params, inputs, outputs, + approx=None): + """Register fully connected layers with shared parameters. + + This can handle general fully-connected layers with shared parameters, but + has specialized approximations to deal with the case where there is a + meaningful linear order to the share instances (such as in an RNN). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs + to layer. In the case of RNNs, one Tensor per time step. + outputs: A list of tensors, the same length as 'inputs', each of shape + [batch_size, output_size]. Outputs produced by layer. In the case of + RNNs, one Tensor per time step. + approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2". + + Raises: + ValueError: For improper value to 'approx'. + """ + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_fully_connected_multi_approximation + has_bias = isinstance(params, (tuple, list)) + + # TODO(b/70283649): something along the lines of find_canonical_output + # should be added back in here (and for the other block types, arguably). + + if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("Bad value {} for approx.".format(approx)) + block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx] + + # For now we don't support multiple minibatches for this type of layer, so + # we set reuse=False + self.register_block(params, + block_type(self, inputs, outputs, has_bias=has_bias), + reuse=False) + def register_categorical_predictive_distribution(self, logits, seed=None, @@ -619,7 +717,6 @@ class LayerCollection(object): key = cls, args if key not in self.fisher_factors: - colo = self._colocate_cov_ops_with_inputs with variable_scope.variable_scope(self._var_scope): - self.fisher_factors[key] = cls(*args, colocate_cov_ops_with_inputs=colo) + self.fisher_factors[key] = cls(*args) return self.fisher_factors[key] diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py index d6bf61a210203dd74d4e93b65005f660b1fab4ff..f8aa230d9ca1f542950f56b1e6cf1ab7ccd3d05f 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -36,6 +36,9 @@ _allowed_symbols = [ "APPROX_DIAGONAL_NAME", "APPROX_FULL_NAME", "VARIABLE_SCOPE", + "APPROX_KRONECKER_INDEP_NAME", + "APPROX_KRONECKER_SERIES_1_NAME", + "APPROX_KRONECKER_SERIES_2_NAME" ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index e2e5bc3ffea3e52087c24802948bc8260e3b199a..2daead2a7180fe57b715bd896303cd4c3fbdaca8 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -22,6 +22,7 @@ import abc import six +from tensorflow.contrib.distributions.python.ops import onehot_categorical from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -91,13 +92,13 @@ class LossFunction(object): @abc.abstractmethod def _evaluate(self, targets): - """Evaluates the log probability of the targets. + """Evaluates the negative log probability of the targets. Args: targets: Tensor that distribution can calculate log_prob() of. Returns: - log probability of each target, summed across all targets. + negative log probability of each target, summed across all targets. """ pass @@ -785,3 +786,16 @@ def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): after[dim] = dim_size - position - 1 return array_ops.pad(slice_to_insert, list(zip(before, after))) + + +class OnehotCategoricalLogitsNegativeLogProbLoss( + CategoricalLogitsNegativeLogProbLoss): + """Neg log prob loss for a categorical distribution with onehot targets. + + Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying + distribution is OneHotCategorical as opposed to Categorical. + """ + + @property + def dist(self): + return onehot_categorical.OneHotCategorical(logits=self._logits) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py index e9bb4f14e9e24128382832fcdaccdc9b24017046..705a871d482565897e7ac850327729a6186f1746 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py @@ -31,6 +31,7 @@ _allowed_symbols = [ "NormalMeanNegativeLogProbLoss", "NormalMeanVarianceNegativeLogProbLoss", "CategoricalLogitsNegativeLogProbLoss", + "OnehotCategoricalLogitsNegativeLogProbLoss", "MultiBernoulliNegativeLogProbLoss", "MultiBernoulliNegativeLogProbLoss", "insert_slice_in_zeros", diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index ecf7f3e4e5ab7d9c151f760fdab733bc3830e37b..1974b07acfc879dc4bc844db9af88fd1043d6698 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -41,12 +41,12 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): damping, layer_collection, var_list=None, - momentum=0., + momentum=0.9, momentum_type="regular", norm_constraint=None, name="KFAC", estimation_mode="gradients", - colocate_gradients_with_ops=False, + colocate_gradients_with_ops=True, cov_devices=None, inv_devices=None): """Initializes the KFAC optimizer with the given settings. @@ -70,8 +70,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): var_list: Optional list or tuple of variables to train. Defaults to the list of variables collected in the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. - momentum: The momentum value for this optimizer. Only applies when - momentum_type is 'regular' or 'adam'. (Default: 0) + momentum: The momentum decay constant to use. Only applies when + momentum_type is 'regular' or 'adam'. (Default: 0.9) momentum_type: The type of momentum to use in this optimizer, one of 'regular', 'adam', or 'qmodel'. (Default: 'regular') norm_constraint: float or Tensor. If specified, the update is scaled down @@ -85,6 +85,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): more a more detailed description of these options. colocate_gradients_with_ops: Whether we should request gradients we compute in the estimator be colocated with their respective ops. + (Default: True) cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified. @@ -136,12 +137,32 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0] self._losses = layer_collection.losses - self.cov_update_op = self._fisher_est.cov_update_op - self.inv_update_op = self._fisher_est.inv_update_op - self.inv_updates_dict = self._fisher_est.inv_updates_dict - super(KfacOptimizer, self).__init__(learning_rate, name=name) + @property + def cov_update_thunks(self): + return self._fisher_est.cov_update_thunks + + @property + def cov_update_ops(self): + return self._fisher_est.cov_update_ops + + @property + def cov_update_op(self): + return self._fisher_est.cov_update_op + + @property + def inv_update_thunks(self): + return self._fisher_est.inv_update_thunks + + @property + def inv_update_ops(self): + return self._fisher_est.inv_update_ops + + @property + def inv_update_op(self): + return self._fisher_est.inv_update_op + @property def variables(self): return self._fisher_est.variables diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index d5461c9f2ea0512ad7c4f2d393ac8e7f441d1b77..e89508fa46b6e2ce278e5373e6c9d17203ad1ef2 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -20,16 +20,22 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables # Method used for inverting matrices. POSDEF_INV_METHOD = "cholesky" +POSDEF_EIG_METHOD = "self_adjoint" def set_global_constants(posdef_inv_method=None): @@ -161,33 +167,11 @@ def mat2d_to_layer_params(vector_template, mat2d): return array_ops.reshape(mat2d, vector_template.shape) -def compute_pi(left_factor, right_factor): - """Computes the scalar constant pi for Tikhonov regularization/damping. - - pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) - See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. - - Args: - left_factor: The left Kronecker factor Tensor. - right_factor: The right Kronecker factor Tensor. - - Returns: - The computed scalar constant pi for these Kronecker Factors (as a Tensor). - """ - # Instead of dividing by the dim of the norm, we multiply by the dim of the - # other norm. This works out the same in the ratio. - left_norm = math_ops.trace(left_factor) * right_factor.get_shape().as_list()[ - 0] - right_norm = math_ops.trace(right_factor) * left_factor.get_shape().as_list()[ - 0] - return math_ops.sqrt(left_norm / right_norm) - - def posdef_inv(tensor, damping): """Computes the inverse of tensor + damping * identity.""" identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) damping = math_ops.cast(damping, dtype=tensor.dtype) - return posdef_inv_funcs[POSDEF_INV_METHOD](tensor, identity, damping) + return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) def posdef_inv_matrix_inverse(tensor, identity, damping): @@ -209,23 +193,51 @@ def posdef_inv_eig(tensor, identity, damping): eigenvectors / eigenvalues, eigenvectors, transpose_b=True) -posdef_inv_funcs = { +posdef_inv_functions = { "matrix_inverse": posdef_inv_matrix_inverse, "cholesky": posdef_inv_cholesky, "eig": posdef_inv_eig, } +def posdef_eig(mat): + """Computes the eigendecomposition of a positive semidefinite matrix.""" + return posdef_eig_functions[POSDEF_EIG_METHOD](mat) + + +def posdef_eig_svd(mat): + """Computes the singular values and left singular vectors of a matrix.""" + evals, evecs, _ = linalg_ops.svd(mat) + + return evals, evecs + + +def posdef_eig_self_adjoint(mat): + """Computes eigendecomposition using self_adjoint_eig.""" + evals, evecs = linalg_ops.self_adjoint_eig(mat) + evals = math_ops.abs(evals) # Should be equivalent to svd approach. + + return evals, evecs + + +posdef_eig_functions = { + "self_adjoint": posdef_eig_self_adjoint, + "svd": posdef_eig_svd, +} + + class SubGraph(object): """Defines a subgraph given by all the dependencies of a given set of outputs. """ def __init__(self, outputs): + # Set of all ancestor Tensors, Ops to 'outputs'. self._members = set() self._recurse_add(outputs) def _recurse_add(self, nodes): + """Recursively adds all of nodes' ancestors.""" for node in nodes: if node in self._members: continue @@ -241,8 +253,25 @@ class SubGraph(object): return node in self._members def variable_uses(self, var): - """Computes number of times a variable is used.""" - return len(self._members.intersection(set(var.value().consumers()))) + """Computes number of times a variable is used. + + Args: + var: Variable or ResourceVariable instance. + + Returns: + Number of times a variable is used within this subgraph. + + Raises: + ValueError: If 'var' is not a variable type. + """ + if isinstance(var, resource_variable_ops.ResourceVariable): + var = var.handle + elif isinstance(var, variables.Variable): + var = var.value() + else: + raise ValueError("%s does not appear to be a variable." % str(var)) + + return len(self._members.intersection(set(var.consumers()))) def filter_list(self, node_list): """Filters 'node_list' to nodes in this subgraph.""" @@ -287,5 +316,109 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): return dysdx + +def on_tpu(): + """Returns True when building a TPU computation.""" + return tpu_function.get_tpu_context().number_of_shards is not None + + +def cross_replica_mean(tensor, name=None): + """Takes mean value of a Tensor across all TPU cores. + + Args: + tensor: Tensor to be synchronized. + name: None or string. Name of Op. + + Returns: + Average of Tensor across all TPU cores. + + Raises: + ValueError: If called outside of TPU context. + """ + with ops.name_scope(name, "cross_replica_mean", [tensor]): + num_shards = tpu_function.get_tpu_context().number_of_shards + if num_shards is None: + raise ValueError( + "Cannot take cross_replica_mean() outside of TPU Context.") + if num_shards == 1: + return tensor + return tpu_ops.cross_replica_sum(tensor / num_shards) + + +def ensure_sequence(obj): + """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" + if isinstance(obj, (tuple, list)): + return obj + else: + return (obj,) + + +def batch_execute(global_step, thunks, batch_size, name=None): + """Executes a subset of ops per global step. + + Given a list of thunks, each of which produces a single stateful op, + ensures that exactly 'batch_size' ops are run per global step. Ops are + scheduled in a round-robin fashion. For example, with 3 ops + + global_step | op0 | op1 | op2 + ------------+-----+-----+----- + 0 | x | x | + ------------+-----+-----+----- + 1 | x | | x + ------------+-----+-----+----- + 2 | | x | x + ------------+-----+-----+----- + 3 | x | x | + ------------+-----+-----+----- + 4 | x | | x + + Does not guarantee order of op execution within a single global step. + + Args: + global_step: Tensor indicating time. Determines which ops run. + thunks: List of thunks. Each thunk encapsulates one op. Return values are + ignored. + batch_size: int. Number of ops to execute per global_step. + name: string or None. Name scope for newly added ops. + + Returns: + List of ops. Exactly 'batch_size' ops are guaranteed to have an effect + every global step. + """ + + def true_fn(thunk): + """Ensures thunk is executed and returns an Op (not a Tensor).""" + + def result(): + with ops.control_dependencies([thunk()]): + return control_flow_ops.no_op() + + return result + + def false_fn(_): + """Executes a no-op.""" + + def result(): + return control_flow_ops.no_op() + + return result + + with ops.name_scope(name, "batch_execute"): + true_fns = [true_fn(thunk) for thunk in thunks] + false_fns = [false_fn(thunk) for thunk in thunks] + num_thunks = len(thunks) + conditions = [ + math_ops.less( + math_ops.mod(batch_size - 1 + global_step * batch_size - j, + num_thunks), batch_size) for j in range(num_thunks) + ] + result = [ + control_flow_ops.cond(condition, true_fn, false_fn) + for (condition, true_fn, + false_fn) in zip(conditions, true_fns, false_fns) + ] + return result + + # TODO(b/69623235): Add a function for finding tensors that share gradients # to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index 9df07d69aad5e61f9cfb994c9a63fdec04f025fe..cc48e3c69f24c2abd343e2e120d3589cd323fcdc 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -30,7 +30,6 @@ _allowed_symbols = [ "kronecker_product", "layer_params_to_mat2d", "mat2d_to_layer_params", - "compute_pi", "posdef_inv", "posdef_inv_matrix_inverse", "posdef_inv_cholesky", @@ -38,6 +37,8 @@ _allowed_symbols = [ "SubGraph", "generate_random_signs", "fwd_gradients", + "ensure_sequence", + "batch_execute", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 226d933d85d91600e36ffb84212703e10455bfbb..b7d34d6435789e54403926a342481971e854b449 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -156,6 +156,10 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +# Imports the core `InputLayer` symbol in contrib during development. +InputLayer = fc_core.InputLayer # pylint: disable=invalid-name + + class _LinearEmbeddingLookupArguments( collections.namedtuple("_LinearEmbeddingLookupArguments", ["input_tensor", @@ -521,7 +525,7 @@ def sparse_column_with_integerized_feature(column_name, Args: column_name: A string defining sparse column name. - bucket_size: An int that is > 1. The number of buckets. It should be bigger + bucket_size: An int that is >= 1. The number of buckets. It should be bigger than maximum feature. In other words features in this column should be an int64 in range [0, bucket_size) combiner: A string specifying how to reduce if the sparse column is @@ -539,7 +543,7 @@ def sparse_column_with_integerized_feature(column_name, An integerized _SparseColumn definition. Raises: - ValueError: bucket_size is not greater than 1. + ValueError: bucket_size is less than 1. ValueError: dtype is not integer. """ return _SparseColumnIntegerized( @@ -748,6 +752,10 @@ class _WeightedSparseColumn( {self.weight_column_name: parsing_ops.VarLenFeature(self.dtype)}) return config + @property + def lookup_config(self): + return self.sparse_id_column.lookup_config + @property def key(self): """Returns a string which will be used as a key when we do sorting.""" diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index fa0047f05d893f6543ddb1680824a32469e13293..78affea44cbfb92523063968dbc1be98841854db 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -97,10 +97,13 @@ def _input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank, - default_name): + default_name, + cols_to_outs=None): """Implementation of `input_from(_sequence)_feature_columns`.""" columns_to_tensors = columns_to_tensors.copy() check_feature_columns(feature_columns) + if cols_to_outs is not None and not isinstance(cols_to_outs, dict): + raise ValueError('cols_to_outs must be a dict unless None') with variable_scope.variable_scope(scope, default_name=default_name, values=columns_to_tensors.values()): @@ -144,6 +147,8 @@ def _input_from_feature_columns(columns_to_tensors, except ValueError as e: raise ValueError('Error creating input layer for column: {}.\n' '{}, {}'.format(column.name, e, ee)) + if cols_to_outs is not None: + cols_to_outs[column] = output_tensors[-1] return array_ops.concat(output_tensors, output_rank - 1) @@ -151,7 +156,8 @@ def input_from_feature_columns(columns_to_tensors, feature_columns, weight_collections=None, trainable=True, - scope=None): + scope=None, + cols_to_outs=None): """A tf.contrib.layers style input layer builder based on FeatureColumns. Generally a single example in training data is described with feature columns. @@ -196,6 +202,8 @@ def input_from_feature_columns(columns_to_tensors, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). scope: Optional scope for variable_scope. + cols_to_outs: Optional dict from feature column to output tensor, + which is concatenated into the returned tensor. Returns: A Tensor which can be consumed by hidden layers in the neural network. @@ -209,7 +217,8 @@ def input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank=2, - default_name='input_from_feature_columns') + default_name='input_from_feature_columns', + cols_to_outs=cols_to_outs) @experimental diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index fbfa0e32de55edab3c90189ddfe05ab826ac9167..e6bbd86ab722c4e853a59f816bed8a8ac1fe9ede 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -607,6 +607,31 @@ class CreateInputLayersForDNNsTest(test.TestCase): # Verify cross compatibility: Core builder output should equal to contrib. self.assertAllEqual(output.eval().shape, output_core.eval().shape) + def testAllDNNColumnsWithColumnwiseOutputs(self): + sparse_column = feature_column.sparse_column_with_keys( + "ids", ["a", "b", "c", "unseen"]) + real_valued_column = feature_column.real_valued_column("income", 2) + one_hot_column = feature_column.one_hot_column(sparse_column) + embedding_column = feature_column.embedding_column(sparse_column, 10) + features = { + "ids": + sparse_tensor.SparseTensor( + values=["c", "b", "a"], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 1]), + "income": + constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]), + } + columns = [one_hot_column, embedding_column, real_valued_column] + cols_to_outs = {} + feature_column_ops.input_from_feature_columns( + features, columns, cols_to_outs=cols_to_outs) + with self.test_session(): + variables_lib.global_variables_initializer().run() + lookup_ops.tables_initializer().run() + for column in columns: + self.assertTrue(column in cols_to_outs) + def testRealValuedColumn(self): real_valued = feature_column.real_valued_column("price") features = {"price": constant_op.constant([[20.], [110], [-3]])} diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index 5ae885b7202357326bd8494d382adb57fa636d20..2eaea231776bd2f5fb8bb4bd422074beacd61720 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -102,6 +102,16 @@ class FeatureColumnTest(test.TestCase): weighted_ids = fc.weighted_sparse_column(ids, "weights") self.assertEqual(weighted_ids.name, "ids_weighted_by_weights") + def testWeightedSparseColumnWithVocabularyFile(self): + ids = fc.sparse_column_with_vocabulary_file( + "ids", "a_file", num_oov_buckets=7, vocab_size=3) + weighted_ids = fc.weighted_sparse_column(ids, "weights") + self.assertEqual(weighted_ids.name, "ids_weighted_by_weights") + self.assertEqual(weighted_ids.lookup_config, ids.lookup_config) + self.assertEqual(weighted_ids.lookup_config.vocab_size, 3) + self.assertEqual(weighted_ids.lookup_config.num_oov_buckets, 7) + self.assertEqual(weighted_ids.lookup_config.vocabulary_file, "a_file") + def testWeightedSparseColumnDeepCopy(self): ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"]) weighted = fc.weighted_sparse_column(ids, "weights") diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index b12a882d9ae88f7cf4f920cfa5872e5de1c67290..51610f21b24f1d40f26630cc1e69ca723d130639 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -79,7 +79,8 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, ``` * To get [Delving Deep into Rectifiers]( - http://arxiv.org/pdf/1502.01852v1.pdf), use (Default):
+ http://arxiv.org/pdf/1502.01852v1.pdf) (also know as the "MSRA + initialization"), use (Default):
`factor=2.0 mode='FAN_IN' uniform=False` * To get [Convolutional Architecture for Fast Feature Embedding]( http://arxiv.org/abs/1408.5093), use:
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 6cd586a5f016c76cc52b340bfd0d32fa08f23748..f3229a1605c72c61d0d1cc638a9a21048ac60cbe 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1896,7 +1896,7 @@ class GDN(base.Layer): outputs.set_shape(inputs.get_shape()) return outputs - def _compute_output_shape(self, input_shape): + def compute_output_shape(self, input_shape): channel_axis = self._channel_axis() input_shape = tensor_shape.TensorShape(input_shape) if not 3 <= input_shape.ndim <= 5: @@ -2561,7 +2561,10 @@ def separable_convolution2d( regularizer=weights_regularizer, trainable=trainable, collections=weights_collections) - strides = [1, 1, stride_h, stride_w] if data_format.startswith('NC') else [1, stride_h, stride_w, 1] + strides = [1, 1, stride_h, + stride_w] if data_format.startswith('NC') else [ + 1, stride_h, stride_w, 1 + ] outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding, rate=utils.two_element_tuple(rate), @@ -2651,7 +2654,7 @@ def spatial_softmax(features, ValueError: If unexpected data_format specified. ValueError: If num_channels dimension is unspecified. """ - with variable_scope.variable_scope(name, 'spatial_softmax'): + with variable_scope.variable_scope(name, 'spatial_softmax'): shape = array_ops.shape(features) static_shape = features.shape if data_format == DATA_FORMAT_NHWC: @@ -2663,30 +2666,39 @@ def spatial_softmax(features, if num_channels.value is None: raise ValueError('The num_channels dimension of the inputs to ' '`spatial_softmax` should be defined. Found `None`.') - - with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]): + + with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]): # Create tensors for x and y coordinate values, scaled to range [-1, 1]. pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), math_ops.lin_space(-1., 1., num=width), indexing='ij') pos_x = array_ops.reshape(pos_x, [height * width]) pos_y = array_ops.reshape(pos_y, [height * width]) + if temperature is None: - temperature_collections = utils.get_variable_collections( - variables_collections, 'temperature') - temperature = variables.model_variable( - 'temperature', - shape=(), - dtype=dtypes.float32, - initializer=init_ops.ones_initializer(), - collections=temperature_collections, - trainable=trainable) + temp_initializer = init_ops.ones_initializer() + else: + temp_initializer = init_ops.constant_initializer(temperature) + + if not trainable: + temp_collections = None + else: + temp_collections = utils.get_variable_collections( + variables_collections, 'temperature') + + temperature = variables.model_variable( + 'temperature', + shape=(), + dtype=dtypes.float32, + initializer=temp_initializer, + collections=temp_collections, + trainable=trainable) if data_format == 'NCHW': features = array_ops.reshape(features, [-1, height * width]) else: features = array_ops.reshape( array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) - + softmax_attention = nn.softmax(features/temperature) expected_x = math_ops.reduce_sum( pos_x * softmax_attention, [1], keep_dims=True) @@ -2699,8 +2711,6 @@ def spatial_softmax(features, return feature_keypoints - - def stack(inputs, layer, stack_args, **kwargs): """Builds a stack of layers by applying layer repeatedly using stack_args. diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index a05e464a26d8167707ce6d6455aca50b0416aa1f..a9bdbe01387653bada1f1e5e9948db7a737eb600 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1747,6 +1747,12 @@ class BatchNormTest(test.TestCase): expected_var *= correction_factor return expected_var, correction_factor + def testBatchNormCenterFalse(self): + a = array_ops.placeholder(dtype=dtypes.float32, shape=(10, 10, 10, 10)) + # Test that center=False builds a valid graph. + _layers.batch_norm(a, center=False, data_format='NCHW', + zero_debias_moving_mean=True) + def testUnknownShape(self): with ops.Graph().as_default() as g, self.test_session(g): inputs = array_ops.placeholder(dtype=dtypes.float32) @@ -3231,7 +3237,11 @@ class SeparableConv2dTest(test.TestCase): images = random_ops.random_uniform((5, height, width, 3), seed=1) regularizer = regularizers.l2_regularizer(0.01) layers_lib.separable_conv2d( - images, 32, [3, 3], 2, weights_regularizer=regularizer) + images, + 32, [3, 3], + 2, + weights_regularizer=regularizer, + weights_initializer=init_ops.ones_initializer()) self.assertEqual( len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 2) weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0] @@ -3239,12 +3249,31 @@ class SeparableConv2dTest(test.TestCase): weight_decay.op.name, 'SeparableConv2d/depthwise_kernel/Regularizer/l2_regularizer') sess.run(variables_lib.global_variables_initializer()) - self.assertLessEqual(sess.run(weight_decay), 0.05) + depth_weight_one = sess.run(weight_decay) weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[1] self.assertEqual( weight_decay.op.name, 'SeparableConv2d/pointwise_kernel/Regularizer/l2_regularizer') - self.assertLessEqual(sess.run(weight_decay), 0.05) + pointwise_weight_one = sess.run(weight_decay) + + regularizer = regularizers.l2_regularizer(1.0) + layers_lib.separable_conv2d( + images, + 32, [3, 3], + 2, + weights_regularizer=regularizer, + weights_initializer=init_ops.ones_initializer()) + self.assertEqual( + len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 4) + weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[2] + sess.run(variables_lib.global_variables_initializer()) + depth_weight_two = sess.run(weight_decay) + weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[3] + pointwise_weight_two = sess.run(weight_decay) + + self.assertAllClose( + [100.0 * depth_weight_one, 100.0 * pointwise_weight_one], + [depth_weight_two, pointwise_weight_two]) def testReuseConvWithWeightDecay(self): height, width = 3, 3 @@ -3332,11 +3361,18 @@ class SeparableConv2dTest(test.TestCase): batch, height, width = 4, 10, 12 kernel_dim, stride = 3, 2 images = random_ops.random_uniform((batch, 3, height, width), seed=1) - output = layers_lib.separable_conv2d(images, num_outputs=num_filters, kernel_size=[kernel_dim, kernel_dim], - depth_multiplier=2, stride=stride, padding='VALID', data_format='NCHW') - self.assertListEqual( - output.get_shape().as_list(), [batch, correct_output_filters, - (height - kernel_dim + 1) // stride, (width - kernel_dim + 1) // stride]) + output = layers_lib.separable_conv2d( + images, + num_outputs=num_filters, + kernel_size=[kernel_dim, kernel_dim], + depth_multiplier=2, + stride=stride, + padding='VALID', + data_format='NCHW') + self.assertListEqual(output.get_shape().as_list(), [ + batch, correct_output_filters, (height - kernel_dim + 1) // stride, + (width - kernel_dim + 1) // stride + ]) class ScaleGradientTests(test.TestCase): diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 31a1b38bd4832c5816136cab3297aa22e843b0f3..123275e1fde047cd3772528641b2e3b09742fbdc 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -34,12 +34,13 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops from tensorflow.python.framework import function from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops -from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest __all__ = ["rev_block", "RevBlock", "recompute_grad"] @@ -137,7 +138,17 @@ def _rev_block_forward(x1, return y1, y2 -class RevBlock(object): +def _scope_wrap(fn, scope): + + @functools.wraps(fn) + def wrap(*args, **kwargs): + with variable_scope.variable_scope(scope): + return fn(*args, **kwargs) + + return wrap + + +class RevBlock(base.Layer): """Block of reversible layers. See rev_block.""" def __init__(self, @@ -146,7 +157,10 @@ class RevBlock(object): num_layers=1, f_side_input=None, g_side_input=None, - use_efficient_backprop=True): + use_efficient_backprop=True, + name="revblock", + **kwargs): + super(RevBlock, self).__init__(name=name, **kwargs) if isinstance(f, list): assert len(f) == num_layers @@ -158,18 +172,8 @@ class RevBlock(object): else: g = [g] * num_layers - scope_prefix = "revblock/revlayer_%d/" - f_scope = scope_prefix + "f" - g_scope = scope_prefix + "g" - - f = [ - template.make_template(f_scope % i, fn, create_scope_now_=True) - for i, fn in enumerate(f) - ] - g = [ - template.make_template(g_scope % i, fn, create_scope_now_=True) - for i, fn in enumerate(g) - ] + f = [_scope_wrap(fn, "revlayer_%d/f" % i) for i, fn in enumerate(f)] + g = [_scope_wrap(fn, "revlayer_%d/g" % i) for i, fn in enumerate(g)] self.f = f self.g = g @@ -180,6 +184,39 @@ class RevBlock(object): self._use_efficient_backprop = use_efficient_backprop + def call(self, inputs, forward=True): + vs = variable_scope.get_variable_scope() + vars_before = vs.global_variables() + + if forward: + x1, x2 = inputs + out = self._forward(x1, x2) + else: + y1, y2 = inputs + out = self._backward(y1, y2) + + # Add any created variables to the Layer's variable stores + new_vars = vs.global_variables()[len(vars_before):] + train_vars = vs.trainable_variables() + for new_var in new_vars: + if new_var in train_vars: + self._trainable_weights.append(new_var) + else: + self._non_trainable_weights.append(new_var) + + return out + + def forward(self, x1, x2): + return self.apply([x1, x2]) + + def backward(self, y1, y2): + return self.apply([y1, y2], forward=False) + + def build(self, _): + logging.warn("RevBlock constructs its variables on first call, not on " + "build.") + self.built = True + def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): """Custom gradient fn for a block of reversible residual layers.""" side_inputs = inputs[2:] @@ -228,17 +265,18 @@ class RevBlock(object): f.reverse() g.reverse() - for i in xrange(self.num_layers): - ys, grad_ys, f_ret, g_ret = _rev_layer_backward( - ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], - self.g_side_input) + with variable_scope.variable_scope(self.scope_name, reuse=True): + for i in xrange(self.num_layers): + ys, grad_ys, f_ret, g_ret = _rev_layer_backward( + ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], + self.g_side_input) - grad_f_vars, grad_f_side = f_ret - grad_g_vars, grad_g_side = g_ret - f_var_grads.append(grad_f_vars) - g_var_grads.append(grad_g_vars) - f_side_grads.append(grad_f_side) - g_side_grads.append(grad_g_side) + grad_f_vars, grad_f_side = f_ret + grad_g_vars, grad_g_side = g_ret + f_var_grads.append(grad_f_vars) + g_var_grads.append(grad_g_vars) + f_side_grads.append(grad_f_side) + g_side_grads.append(grad_g_side) # Accumulate layer gradients for f_side_input and g_side_input acc_f_side_grads = _acc_grads(*f_side_grads) @@ -265,7 +303,7 @@ class RevBlock(object): grad_x1, grad_x2 = grad_ys return [grad_x1, grad_x2] + side_input_grads, variable_grads - def forward(self, x1, x2): + def _forward(self, x1, x2): """Run forward through the reversible layers.""" side_inputs = [self.f_side_input, self.g_side_input] @@ -275,7 +313,7 @@ class RevBlock(object): self._efficient_grad_fn if self._use_efficient_backprop else None) @_fn_with_custom_grad(custom_grad_fn) - def _forward(x1_, x2_, *flat_side_inputs): + def _forward_wrap(x1_, x2_, *flat_side_inputs): f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs) return _rev_block_forward( x1_, @@ -287,9 +325,9 @@ class RevBlock(object): g_side_input=g_side, gate_outputs=self._use_efficient_backprop) - return _forward(x1, x2, *flat_side_inputs) + return _forward_wrap(x1, x2, *flat_side_inputs) - def backward(self, y1, y2): + def _backward(self, y1, y2): """Run backward through the reversible layers.""" f = list(self.f) @@ -356,7 +394,14 @@ def rev_block(x1, Returns: y1, y2: tuple of float Tensors. """ - block = RevBlock(f, g, num_layers, f_side_input, g_side_input, is_training) + block = RevBlock( + f=f, + g=g, + num_layers=num_layers, + f_side_input=f_side_input, + g_side_input=g_side_input, + use_efficient_backprop=is_training, + _reuse=variable_scope.get_variable_scope().reuse) return block.forward(x1, x2) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index a420753fd5728e7eef4f135d4943d25e8e05d5c2..cbcbcd75114a522b95631e4e7e95c1641b0a9987 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -188,13 +188,46 @@ class RevBlockTest(test.TestCase): def f(x): x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = core_layers.batch_normalization(x, training=True) + x = layers.batch_norm(x, is_training=True) x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = core_layers.batch_normalization(x, training=True) + x = layers.batch_norm(x, is_training=True) return x self._testRevBlock(x=x, f=f) + def testReuse(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2) + + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + with variable_scope.variable_scope("test"): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_before = len(variables.global_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + + loss = math_ops.reduce_mean(y1 + y2) + _ = gradients_impl.gradients(loss, + [x] + variables.trainable_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + class RecomputeTest(test.TestCase): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 94920db574e07529c28313a78e0128676fcc7970..ee3611ca9385e80d30e42f8405c8ac318e66771b 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -10,7 +10,7 @@ package(default_visibility = [ "//tensorflow:internal", ]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_library( name = "learn", @@ -154,12 +154,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "experiment_test", size = "medium", srcs = ["python/learn/experiment_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", "//tensorflow/core:protos_all_py", @@ -173,6 +172,17 @@ py_test( ], ) +py_test( + name = "export_strategy_test", + size = "small", + srcs = ["python/learn/export_strategy_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "graph_actions_test", size = "small", @@ -346,6 +356,7 @@ py_test( srcs = ["python/learn/estimators/dnn_linear_combined_test.py"], shard_count = 4, srcs_version = "PY2AND3", + tags = ["no_oss"], # flaky b/70524820 deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -461,6 +472,7 @@ py_test( size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], srcs_version = "PY2AND3", + tags = ["noasan"], deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -715,12 +727,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "graph_io_test", size = "small", srcs = ["python/learn/learn_io/graph_io_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -736,20 +747,7 @@ py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], -) - -py_test( - name = "numpy_io_test", - size = "small", - srcs = ["python/learn/learn_io/numpy_io_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":learn", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], + grpc_enabled = True, ) py_test( diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py index 14750961efa30128708430fac038498de0a42118..ef5e620e8f08cffa7c2b945089aa5d150baefefc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import composable_model @@ -55,7 +55,7 @@ def _base_model_fn(features, labels, mode, params): raise NotImplementedError def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() assert global_step train_step = model.get_train_step(loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index cb15ef23e95d27c737d8ae08065b804bafd39a07..c17b41c0f767e19d9c3635a8f60347a49b297cfb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -23,7 +23,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec @@ -189,7 +189,7 @@ def _dnn_model_fn(features, labels, mode, params, config=None): """Returns the op to optimize the loss.""" return optimizers.optimize_loss( loss=loss, - global_step=contrib_variables.get_global_step(), + global_step=training_util.get_global_step(), learning_rate=_LEARNING_RATE, optimizer=_get_optimizer(optimizer), gradient_multipliers=( diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 57e70e169ca9d6fb2adc4e50bf387cc7cf330aed..4e65c180d8bee9ab8fe9b1fbf32edc229c31af09 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -1046,11 +1046,14 @@ class DNNLinearCombinedClassifierTest(test.TestCase): if global_step == 100: # Expected is 100, but because of the global step increment bug, is 50. - self.assertEqual(50, step_counter.steps) + # Occasionally, step increments one more time due to a race condition, + # reaching 51 steps. + self.assertIn(step_counter.steps, [50, 51]) else: - # Occasionally, training stops when global_step == 101, due to a race - # condition. - self.assertEqual(51, step_counter.steps) + # Occasionally, training stops when global_step == 102, due to a race + # condition. In addition, occasionally step increments one more time due + # to a race condition reaching 52 steps. + self.assertIn(step_counter.steps, [51, 52]) def testGlobalStepDNNLinearCombinedBugFixed(self): """Tests global step update for dnn-linear combined model.""" diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 788d2d0b1a58fad16712c968593b40de0d3979f0..50c74add86fcf62c738e81426bfaf842fbac2b4e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -30,7 +30,6 @@ import six from google.protobuf import message from tensorflow.contrib import layers -from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework import list_variables @@ -60,6 +59,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -360,10 +360,23 @@ def _write_dict_to_summary(output_dir, dictionary, current_global_step): logging.warn('Skipping summary for %s, cannot parse string to Summary.', key) continue + elif isinstance(dictionary[key], np.ndarray): + value = summary_proto.value.add() + value.tag = key + value.node_name = key + tensor_proto = tensor_util.make_tensor_proto(dictionary[key]) + value.tensor.CopyFrom(tensor_proto) + logging.info( + 'Summary for np.ndarray is not visible in Tensorboard by default. ' + 'Consider using a Tensorboard plugin for visualization (see ' + 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md ' # pylint:disable=line-too-long + 'for more information).' + ) else: logging.warn( 'Skipping summary for %s, must be a float, np.float32, np.int64, ' - 'np.int32 or int or a serialized string of Summary.', key) + 'np.int32 or int or np.ndarray or a serialized string of Summary.', + key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() @@ -1230,7 +1243,7 @@ class Estimator(BaseEstimator): if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops: model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = ( - metrics_lib.streaming_mean(model_fn_ops.loss)) + metrics_lib.mean(model_fn_ops.loss)) return model_fn_ops def _get_predict_ops(self, features): @@ -1256,7 +1269,9 @@ class Estimator(BaseEstimator): assets_extra=None, as_text=False, checkpoint_path=None, - graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),)): + graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),), + strip_default_attrs=False): + # pylint: disable=line-too-long """Exports inference graph as a SavedModel into given dir. Args: @@ -1280,6 +1295,9 @@ class Estimator(BaseEstimator): produce a separate MetaGraphDef within the exported SavedModel, tagged and rewritten as specified. Defaults to a single entry using the default serving tag ("serve") and no rewriting. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: The string path to the exported directory. @@ -1287,6 +1305,7 @@ class Estimator(BaseEstimator): Raises: ValueError: if an unrecognized export_type is requested. """ + # pylint: enable=line-too-long if serving_input_fn is None: raise ValueError('serving_input_fn must be defined.') @@ -1366,7 +1385,8 @@ class Estimator(BaseEstimator): signature_def_map=signature_def_map, assets_collection=ops.get_collection( ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=init_op) + legacy_init_op=init_op, + strip_default_attrs=strip_default_attrs) # pylint: disable=protected-access base_meta_graph_def = builder._saved_model.meta_graphs[0] diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py index 248c6c733ffca351c848ba07110ba89928634a23..9d7c1a099aa4be64ca0296fa5b870597dabec7b4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py @@ -23,7 +23,7 @@ import tempfile import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn import models @@ -114,7 +114,7 @@ def linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -129,7 +129,7 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return prediction, loss, train_op @@ -139,7 +139,7 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -150,7 +150,7 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index be2b0cb3ca959323b4de095ca072278f028be301..5f682838b7afadec7a54df782cb5b89ac6746659 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -32,7 +32,7 @@ from google.protobuf import text_format from tensorflow.contrib import learn from tensorflow.contrib import lookup -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import experiment @@ -132,7 +132,7 @@ def linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -147,7 +147,7 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return prediction, loss, train_op @@ -157,7 +157,7 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -168,7 +168,7 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction @@ -241,7 +241,7 @@ def _build_estimator_for_resource_export_test(): const = constant_op.constant(-1, dtype=dtypes.int64) table = lookup.MutableHashTable( dtypes.string, dtypes.int64, const, name='LookupTableModel') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL): key = constant_op.constant(['key']) value = constant_op.constant([42], dtype=dtypes.int64) @@ -306,7 +306,7 @@ def _model_fn_ops( mode=mode, predictions=constant_op.constant(0.), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1)) + train_op=training_util.get_global_step().assign_add(1)) def _make_input_fn(features, labels): @@ -389,7 +389,7 @@ class EstimatorModelFnTest(test.TestCase): self.assertEqual(expected_param, params) self.assertEqual(model_dir, expected_model_dir) return (constant_op.constant(0.), constant_op.constant(0.), - variables.get_global_step().assign_add(1)) + training_util.get_global_step().assign_add(1)) est = estimator.Estimator(model_fn=_argument_checker, params=expected_param, model_dir=expected_model_dir) @@ -400,7 +400,7 @@ class EstimatorModelFnTest(test.TestCase): def _invalid_model_fn(features, labels): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): loss = 100.0 - w return None, loss, None @@ -415,7 +415,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) predictions = loss @@ -434,7 +434,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) return None, loss, train_op @@ -464,7 +464,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant(0.), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(init_fn=_init_fn)) est = estimator.Estimator(model_fn=_model_fn_scaffold) @@ -483,7 +483,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant([[1.]]), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(saver=self.mock_saver)) def input_fn(): @@ -884,6 +884,35 @@ class EstimatorTest(test.TestCase): self.assertTrue('MSE' in output_values) self.assertTrue(output_values['MSE'].HasField('histo')) + def testSummaryWritingWithTensor(self): + + def _streaming_precition_mean_tensor(predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + return metric_ops.streaming_mean_tensor( + predictions, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) + + est = estimator.Estimator(model_fn=linear_model_fn) + est.fit(input_fn=boston_input_fn, steps=200) + est.evaluate( + input_fn=boston_input_fn, + steps=200, + metrics={'PMT': _streaming_precition_mean_tensor}) + events = util_test.latest_events(est.model_dir + '/eval') + output_values = {} + for e in events: + if e.HasField('summary'): + for v in e.summary.value: + output_values[v.tag] = v + self.assertTrue('PMT' in output_values) + self.assertTrue(output_values['PMT'].HasField('tensor')) + def testLossInGraphCollection(self): class _LossCheckerHook(session_run_hook.SessionRunHook): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py index 1d89dfb55b10b032cab7dcf434d396404d4eb83b..8131e0fde6fea5501cacc4714f53ed8d867ca70f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py @@ -22,7 +22,7 @@ import random import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python import learn from tensorflow.contrib.learn.python.learn import datasets from tensorflow.contrib.learn.python.learn import metric_spec @@ -62,7 +62,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["transformed_x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( @@ -100,7 +100,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( @@ -139,7 +139,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator_with_fe_fn = estimator_lib.Estimator( diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 992b804f59ecd88fedc2fba10d3079f93c4fe83d..8f9d6fc318a357853bdb8e3264f6691b410006b1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -28,7 +28,7 @@ import time import numpy as np from tensorflow.contrib.factorization.python.ops import clustering_ops -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps from tensorflow.python.framework import ops @@ -128,7 +128,7 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): random_seed=params.get('random_seed'), kmeans_plus_plus_num_retries=params.get( 'kmeans_plus_plus_num_retries')).training_graph() - incr_step = state_ops.assign_add(variables.get_global_step(), 1) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) loss = math_ops.reduce_sum(losses, name=KMeansClustering.LOSS_OP_NAME) summary.scalar('loss/raw', loss) training_op = with_dependencies([training_op, incr_step], loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py index ce87b4723d436495e5fb149f0ab8f2eea44d82b8..b28835a809736a099ad2f08d127dc68d7977a3c1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py @@ -199,15 +199,7 @@ class KMeansTest(KMeansTestBase): input_fn=self.input_fn(batch_size=self.num_points), steps=1) self.assertNear(self.true_score, score, self.true_score * 0.01) - def test_infer(self): - kmeans = self._kmeans() - # Make a call to fit to initialize the cluster centers. - max_steps = 1 - kmeans.fit(input_fn=self.input_fn(), max_steps=max_steps) - clusters = kmeans.clusters() - - # Make a small test set - num_points = 10 + def _infer_helper(self, kmeans, clusters, num_points): points, true_assignments, true_offsets = make_random_points( clusters, num_points) # Test predict @@ -231,6 +223,17 @@ class KMeansTest(KMeansTestBase): np.transpose(np.sum(np.square(clusters), axis=1, keepdims=True))) self.assertAllClose(transform, true_transform, rtol=0.05, atol=10) + def test_infer(self): + kmeans = self._kmeans() + # Make a call to fit to initialize the cluster centers. + max_steps = 1 + kmeans.fit(input_fn=self.input_fn(), max_steps=max_steps) + clusters = kmeans.clusters() + + # Run inference on small datasets. + self._infer_helper(kmeans, clusters, num_points=10) + self._infer_helper(kmeans, clusters, num_points=1) + class KMeansTestMultiStageInit(KMeansTestBase): diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index f5445ad4e728dbd3904279573771de9454b5d17c..37aa8b339622415d082933cdf66d2472a4119b48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -26,7 +26,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -170,7 +170,7 @@ def _linear_model_fn(features, labels, mode, params, config=None): weight_collections=[parent_scope]) def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() my_vars = ops.get_collection(parent_scope) grads = gradients.gradients(loss, my_vars) if gradient_clip_norm: @@ -252,7 +252,7 @@ def sdca_model_fn(features, labels, mode, params): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step(columns_to_variables, weight_column_name, loss_type, features, diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py index 93c62f87e8495f299a8c456574c7b40534186304..656d68b76888d9319c0b9be481f9b0478ac4314c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor @@ -57,7 +57,7 @@ def _logistic_regression_model_fn(features, labels, mode): predictions = math_ops.sigmoid(logits) loss = losses.sigmoid_cross_entropy(labels, logits) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return predictions, loss, train_op diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index fc4bd1f461d7bfbfcfb78201d527959055342f0a..9576ff21c243022276bb0641882dfaf0decf05c0 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -35,6 +35,7 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator +from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks @@ -46,6 +47,18 @@ from tensorflow.python.util import compat __all__ = ["Experiment"] +def _get_standardized_predicate_fn(predicate_fn): + pred_fn_args = estimator_util.fn_args(predicate_fn) + if "checkpoint_path" not in pred_fn_args: + # pylint: disable=unused-argument + def _pred_fn_wrapper(eval_results, checkpoint_path): + return predicate_fn(eval_results) + + return _pred_fn_wrapper + else: + return predicate_fn + + class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): """Listener that evaluates and exports a model after creating a checkpoint. @@ -446,22 +459,33 @@ class Experiment(object): evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints that have already been evaluated. Default is `True`. continuous_eval_predicate_fn: A predicate function determining whether to - continue eval after each iteration. `predicate_fn` takes the evaluation - results as arguments. At the beginning of evaluation, the passed eval - results will be None so it's expected that the predicate function - handles that gracefully. When `predicate_fn` is not specified, - continuous eval will run in an infinite loop (if `train_steps` is None) - or exit once global step reaches `train_steps`. + continue eval after each iteration. A `predicate_fn` has one of the + following signatures: + * (eval_results) -> boolean + * (eval_results, checkpoint_path) -> boolean + Where `eval_results` is the dictionary of metric evaluations and + checkpoint_path is the path to the checkpoint containing the parameters + on which that evaluation was based. + At the beginning of evaluation, the passed `eval_results` will be None + so it's expected that the predicate function handles that gracefully. + When `predicate_fn` is not specified, continuous eval will run in an + infinite loop (if `train_steps` is None). or exit once global step + reaches `train_steps`. + export: Whether to export from this step. Default is 'True'. Raises: ValueError: if `continuous_eval_predicate_fn` is neither None nor callable. """ - if (continuous_eval_predicate_fn is not None and - not callable(continuous_eval_predicate_fn)): - raise ValueError( - "`continuous_eval_predicate_fn` must be a callable, or None.") + if continuous_eval_predicate_fn is not None: + if not callable(continuous_eval_predicate_fn): + raise ValueError( + "`continuous_eval_predicate_fn` must be a callable, or None.") + predicate_fn = _get_standardized_predicate_fn( + continuous_eval_predicate_fn) + else: + predicate_fn = None if delay_secs is None: delay_secs = self._eval_delay_secs @@ -475,8 +499,10 @@ class Experiment(object): previous_path = None eval_result = None last_warning_time = 0 - while (not continuous_eval_predicate_fn or - continuous_eval_predicate_fn(eval_result)): + while (not predicate_fn or + predicate_fn( + eval_result, + checkpoint_path=previous_path if eval_result else None)): # Exit if we have already reached number of steps to train. if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " @@ -682,11 +708,19 @@ class Experiment(object): Args: continuous_eval_predicate_fn: A predicate function determining whether to - continue after each iteration. `predicate_fn` takes the evaluation - results as its arguments. At the beginning of evaluation, the passed - eval results will be None so it's expected that the predicate function - handles that gracefully. When `predicate_fn` is not specified, this will - run in an infinite loop or exit when global_step reaches `train_steps`. + continue eval after each iteration. A `predicate_fn` has one of the + following signatures: + * (eval_results) -> boolean + * (eval_results, checkpoint_path) -> boolean + Where `eval_results` is the dictionary of metric evaluations and + checkpoint_path is the path to the checkpoint containing the parameters + on which that evaluation was based. + At the beginning of evaluation, the passed `eval_results` and + `checkpoint_path` will be None so it's expected that the predicate + function handles that gracefully. + When `predicate_fn` is not specified, continuous eval will run in an + infinite loop (if `train_steps` is None). or exit once global step + reaches `train_steps`. Returns: A tuple of the result of the `evaluate` call to the `Estimator` and the @@ -697,13 +731,18 @@ class Experiment(object): callable. """ - if (continuous_eval_predicate_fn is not None and - not callable(continuous_eval_predicate_fn)): - raise ValueError( - "`continuous_eval_predicate_fn` must be a callable, or None.") + if continuous_eval_predicate_fn is not None: + if not callable(continuous_eval_predicate_fn): + raise ValueError( + "`continuous_eval_predicate_fn` must be a callable, or None.") + predicate_fn = _get_standardized_predicate_fn( + continuous_eval_predicate_fn) + else: + predicate_fn = None - eval_result = None export_results = None + latest_checkpoint = None + eval_result = None # Set the default value for train_steps_per_iteration, which will be # overridden by other settings. @@ -713,8 +752,10 @@ class Experiment(object): elif self._train_steps is not None: train_steps_per_iteration = int(self._train_steps / 10) - while (not continuous_eval_predicate_fn or - continuous_eval_predicate_fn(eval_result)): + while (not predicate_fn or + predicate_fn( + eval_result, + checkpoint_path=latest_checkpoint if eval_result else None)): if self._has_training_stopped(eval_result): # Exits once max steps of training is satisfied. @@ -729,11 +770,14 @@ class Experiment(object): saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") - eval_result = self._call_evaluate(input_fn=self._eval_input_fn, - steps=self._eval_steps, - metrics=self._eval_metrics, - name="one_pass", - hooks=self._eval_hooks) + latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) + eval_result = self._call_evaluate( + input_fn=self._eval_input_fn, + steps=self._eval_steps, + metrics=self._eval_metrics, + name="one_pass", + checkpoint_path=latest_checkpoint, + hooks=self._eval_hooks) export_results = self._maybe_export(eval_result) return eval_result, export_results diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index c29c198d094090a59c8c7dd2949c3f069adf49d0..545d7d8924c0c10544e6113e2968b7ae3d2090fc 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -492,6 +492,33 @@ class ExperimentTest(test.TestCase): self.assertEqual(3, est.eval_count) self.assertEqual([noop_hook], est.eval_hooks) + def test_continuous_eval_predicate_fn_with_checkpoint(self): + for est in self._estimators_for_tests(): + eval_metrics = 'eval_metrics' if not isinstance( + est, core_estimator.Estimator) else None + est.fake_checkpoint() + noop_hook = _NoopHook() + + def _predicate_fn(eval_result, checkpoint_path): + self.assertEqual(not eval_result, + checkpoint_path is None) + return est.eval_count < 3 # pylint: disable=cell-var-from-loop + + ex = experiment.Experiment( + est, + train_input_fn='train_input', + eval_input_fn='eval_input', + eval_metrics=eval_metrics, + eval_hooks=[noop_hook], + eval_delay_secs=0, + continuous_eval_throttle_secs=0) + ex.continuous_eval( + evaluate_checkpoint_only_once=False, + continuous_eval_predicate_fn=_predicate_fn) + self.assertEqual(0, est.fit_count) + self.assertEqual(3, est.eval_count) + self.assertEqual([noop_hook], est.eval_hooks) + def test_run_local(self): for est in self._estimators_for_tests(): eval_metrics = 'eval_metrics' if not isinstance( diff --git a/tensorflow/contrib/learn/python/learn/export_strategy.py b/tensorflow/contrib/learn/python/learn/export_strategy.py index f276aab0e6beb011a21c20fa194dd5212db796d1..55a8b824312b89e0ac66513242191f4201ac212a 100644 --- a/tensorflow/contrib/learn/python/learn/export_strategy.py +++ b/tensorflow/contrib/learn/python/learn/export_strategy.py @@ -26,13 +26,14 @@ __all__ = ['ExportStrategy'] class ExportStrategy( - collections.namedtuple('ExportStrategy', ['name', 'export_fn'])): + collections.namedtuple('ExportStrategy', + ['name', 'export_fn', 'strip_default_attrs'])): """A class representing a type of model export. Typically constructed by a utility function specific to the exporter, such as `saved_model_export_utils.make_export_strategy()`. - The fields are: + Attributes: name: The directory name under the export base directory where exports of this type will be written. export_fn: A function that writes an export, given an estimator, a @@ -45,11 +46,20 @@ class ExportStrategy( The signature of this function must be one of: - * `(estimator, export_path) -> export_path` - * `(estimator, export_path, checkpoint_path) -> export_path` - * `(estimator, export_path, checkpoint_path, eval_result) -> export_path` + * `(estimator, export_path) -> export_path` + * `(estimator, export_path, checkpoint_path) -> export_path` + * `(estimator, export_path, checkpoint_path, eval_result) -> export_path` + * `(estimator, export_path, checkpoint_path, eval_result, + strip_default_attrs) -> export_path` + strip_default_attrs: (Optional) Boolean. If set as True, default attrs in + the `GraphDef` will be stripped on write. This is recommended for better + forward compatibility of the resulting `SavedModel`. """ + def __new__(cls, name, export_fn, strip_default_attrs=None): + return super(ExportStrategy, cls).__new__( + cls, name, export_fn, strip_default_attrs) + def export(self, estimator, export_path, @@ -83,5 +93,6 @@ class ExportStrategy( raise ValueError('An export_fn accepting eval_result must also accept ' 'checkpoint_path.') kwargs['eval_result'] = eval_result - + if 'strip_default_attrs' in export_fn_args: + kwargs['strip_default_attrs'] = self.strip_default_attrs return self.export_fn(estimator, export_path, **kwargs) diff --git a/tensorflow/contrib/learn/python/learn/export_strategy_test.py b/tensorflow/contrib/learn/python/learn/export_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..43c3551cccc3b8e6b66bd2b36839a3dfc5fe8eea --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/export_strategy_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ExportStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn import export_strategy +from tensorflow.python.platform import test + + +class ExportStrategyTest(test.TestCase): + + def test_no_optional_args_export(self): + model_path = '/path/to/model' + def _export_fn(estimator, export_path): + self.assertTupleEqual((estimator, export_path), (None, None)) + return model_path + + strategy = export_strategy.ExportStrategy('foo', _export_fn) + self.assertTupleEqual(strategy, ('foo', _export_fn, None)) + self.assertIs(strategy.export(None, None), model_path) + + def test_checkpoint_export(self): + ckpt_model_path = '/path/to/checkpoint_model' + def _ckpt_export_fn(estimator, export_path, checkpoint_path): + self.assertTupleEqual((estimator, export_path), (None, None)) + self.assertEqual(checkpoint_path, 'checkpoint') + return ckpt_model_path + + strategy = export_strategy.ExportStrategy('foo', _ckpt_export_fn) + self.assertTupleEqual(strategy, ('foo', _ckpt_export_fn, None)) + self.assertIs(strategy.export(None, None, 'checkpoint'), ckpt_model_path) + + def test_checkpoint_eval_export(self): + ckpt_eval_model_path = '/path/to/checkpoint_eval_model' + def _ckpt_eval_export_fn(estimator, export_path, checkpoint_path, + eval_result): + self.assertTupleEqual((estimator, export_path), (None, None)) + self.assertEqual(checkpoint_path, 'checkpoint') + self.assertEqual(eval_result, 'eval') + return ckpt_eval_model_path + + strategy = export_strategy.ExportStrategy('foo', _ckpt_eval_export_fn) + self.assertTupleEqual(strategy, ('foo', _ckpt_eval_export_fn, None)) + self.assertIs(strategy.export(None, None, 'checkpoint', 'eval'), + ckpt_eval_model_path) + + def test_eval_only_export(self): + def _eval_export_fn(estimator, export_path, eval_result): + del estimator, export_path, eval_result + + strategy = export_strategy.ExportStrategy('foo', _eval_export_fn) + self.assertTupleEqual(strategy, ('foo', _eval_export_fn, None)) + with self.assertRaisesRegexp(ValueError, 'An export_fn accepting ' + 'eval_result must also accept ' + 'checkpoint_path'): + strategy.export(None, None, eval_result='eval') + + def test_strip_default_attr_export(self): + strip_default_attrs_model_path = '/path/to/strip_default_attrs_model' + def _strip_default_attrs_export_fn(estimator, export_path, + strip_default_attrs): + self.assertTupleEqual((estimator, export_path), (None, None)) + self.assertTrue(strip_default_attrs) + return strip_default_attrs_model_path + + strategy = export_strategy.ExportStrategy('foo', + _strip_default_attrs_export_fn, + True) + self.assertTupleEqual(strategy, + ('foo', _strip_default_attrs_export_fn, True)) + self.assertIs(strategy.export(None, None), strip_default_attrs_model_path) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 86fad4c5535a918d87e0741687cfebe3afaf9ddf..f36a778b529a83f158241ddb060959c4b33e2e95 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -857,8 +857,8 @@ class DaskDataFeeder(object): """Returns a function, that will sample data and provide it to placeholders. Args: - input_placeholder: tf.Placeholder for input features mini batch. - output_placeholder: tf.Placeholder for output labels. + input_placeholder: tf.placeholder for input features mini batch. + output_placeholder: tf.placeholder for output labels. Returns: A function that when called samples a random subset of batch size diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 4b34fc62849766370979bb2002d42ee03ea7161a..3a46c239688017f9204d2c6182a6f81cd325a417 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import io_ops @@ -280,14 +281,33 @@ def _get_file_names(file_pattern, randomize_input): def _get_examples(file_name_queue, reader, num_threads, read_batch_size, filter_fn, parse_fn): + """Get example filenames matching. + + Args: + file_name_queue: A queue implementation that dequeues elements in + first-in first-out order. + reader: A function or class that returns an object with + `read` method, (filename tensor) -> (example tensor). + num_threads: The number of threads enqueuing examples. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. + filter_fn: Filtering function, takes both keys as well as an `Example` + Tensors and returns a boolean mask of the same shape as the input Tensors + to be applied for filtering. If `None`, no filtering is done. + parse_fn: Parsing function, takes `Example` Tensor returns parsed + representation. If `None`, no parsing is done. + + Returns: + List of example file names matching `file_name_queue`. + """ with ops.name_scope('read'): example_list = [] for _ in range(num_threads): - if read_batch_size > 1: - keys, examples_proto = reader().read_up_to(file_name_queue, - read_batch_size) - else: - keys, examples_proto = reader().read(file_name_queue) + keys, examples_proto = utils.smart_cond( + read_batch_size > 1, + lambda: reader().read_up_to(file_name_queue, read_batch_size), + lambda: reader().read(file_name_queue)) + if filter_fn: mask = filter_fn(keys, examples_proto) keys = array_ops.boolean_mask(keys, mask) @@ -379,14 +399,15 @@ def _read_keyed_batch_examples_helper(file_pattern, capacity=1, dtypes=[dtypes.string], shapes=[[]]) enqueue_op = file_name_queue.enqueue( input_pipeline_ops.seek_next( - file_names, shuffle=randomize_input, num_epochs=num_epochs, + file_names, + shuffle=randomize_input, + num_epochs=num_epochs, seed=seed)) queue_runner.add_queue_runner( queue_runner.QueueRunner(file_name_queue, [enqueue_op])) else: file_name_queue = input_ops.string_input_producer( - constant_op.constant( - file_names, name='input'), + constant_op.constant(file_names, name='input'), shuffle=randomize_input, num_epochs=num_epochs, name=file_name_queue_scope, @@ -496,7 +517,8 @@ def read_keyed_batch_features(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: - if read_batch_size is None: read_batch_size = batch_size + if read_batch_size is None: + read_batch_size = batch_size keys, examples = read_keyed_batch_examples( file_pattern, batch_size, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 6f0fd9a2976d37d1c701a96f50c2b987562cb191..e11e8b698adc113486bbb45572c8129e964cc931 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -204,8 +204,7 @@ class GraphIOTest(test.TestCase): shape = (0,) features = { "feature": - parsing_ops.FixedLenFeature( - shape=shape, dtype=dtypes_lib.float32) + parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32) } with ops.Graph().as_default() as g, self.test_session(graph=g) as sess: @@ -255,8 +254,8 @@ class GraphIOTest(test.TestCase): self.assertAllEqual((None,), inputs.get_shape().as_list()) self.assertEqual("%s:1" % name, inputs.name) file_name_queue_name = "%s/file_name_queue" % name - file_name_queue_limit_name = ("%s/limit_epochs/epochs" % - file_name_queue_name) + file_name_queue_limit_name = ( + "%s/limit_epochs/epochs" % file_name_queue_name) file_names_name = "%s/input" % file_name_queue_name example_queue_name = "%s/random_shuffle_queue" % name op_nodes = test_util.assert_ops_in_graph({ @@ -354,8 +353,8 @@ class GraphIOTest(test.TestCase): json_lines = [ "".join([ '{"features": { "feature": { "sequence": {', - '"bytes_list": { "value": ["', base64.b64encode(l).decode("ascii"), - '"]}}}}}\n' + '"bytes_list": { "value": ["', + base64.b64encode(l).decode("ascii"), '"]}}}}}\n' ]) for l in lines ] return self._create_temp_file("".join(json_lines)) @@ -823,6 +822,31 @@ class GraphIOTest(test.TestCase): coord.request_stop() coord.join(threads) + def test_read_keyed_batch_features_shared_queue(self): + batch_size = 17 + shape = (0,) + fixed_feature = parsing_ops.FixedLenFeature( + shape=shape, dtype=dtypes_lib.float32) + feature = {"feature": fixed_feature} + reader = io_ops.TFRecordReader + + _, queued_feature = graph_io.read_keyed_batch_features_shared_queue( + _VALID_FILE_PATTERN, batch_size, feature, reader) + + with ops.Graph().as_default() as g, self.test_session(graph=g) as session: + features_result = graph_io.read_batch_features( + _VALID_FILE_PATTERN, batch_size, feature, reader) + session.run(variables.local_variables_initializer()) + + self.assertAllEqual( + queued_feature.get("feature").get_shape().as_list(), + features_result.get("feature").get_shape().as_list()) + + def test_get_file_names_errors(self): + # Raise bad file_pattern. + with self.assertRaises(ValueError): + graph_io._get_file_names([], True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py deleted file mode 100644 index 6fe8de8705b8854e5861879d2a505fe03fddc7e5..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for numpy_io.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.learn.python.learn.learn_io import numpy_io -from tensorflow.python.framework import errors -from tensorflow.python.platform import test -from tensorflow.python.training import coordinator -from tensorflow.python.training import queue_runner_impl - - -class NumpyIoTest(test.TestCase): - - def testNumpyInputFn(self): - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - y = np.arange(-32, -28) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - session.run([features, target]) - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithVeryLargeBatchSizeAndMultipleEpochs(self): - a = np.arange(2) * 1.0 - b = np.arange(32, 34) - x = {'a': a, 'b': b} - y = np.arange(-32, -30) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=128, shuffle=False, num_epochs=2) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1, 0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33, 32, 33]) - self.assertAllEqual(res[1], [-32, -31, -32, -31]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithZeroEpochs(self): - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - y = np.arange(-32, -28) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=0) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithBatchSizeNotDividedByDataSize(self): - batch_size = 2 - a = np.arange(5) * 1.0 - b = np.arange(32, 37) - x = {'a': a, 'b': b} - y = np.arange(-32, -27) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=batch_size, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [2, 3]) - self.assertAllEqual(res[0]['b'], [34, 35]) - self.assertAllEqual(res[1], [-30, -29]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [4]) - self.assertAllEqual(res[0]['b'], [36]) - self.assertAllEqual(res[1], [-28]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithBatchSizeNotDividedByDataSizeAndMultipleEpochs(self): - batch_size = 2 - a = np.arange(3) * 1.0 - b = np.arange(32, 35) - x = {'a': a, 'b': b} - y = np.arange(-32, -29) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=batch_size, shuffle=False, num_epochs=3) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [2, 0]) - self.assertAllEqual(res[0]['b'], [34, 32]) - self.assertAllEqual(res[1], [-30, -32]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [1, 2]) - self.assertAllEqual(res[0]['b'], [33, 34]) - self.assertAllEqual(res[1], [-31, -30]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [2]) - self.assertAllEqual(res[0]['b'], [34]) - self.assertAllEqual(res[1], [-30]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithBatchSizeLargerThanDataSize(self): - batch_size = 10 - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - y = np.arange(-32, -28) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=batch_size, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1, 2, 3]) - self.assertAllEqual(res[0]['b'], [32, 33, 34, 35]) - self.assertAllEqual(res[1], [-32, -31, -30, -29]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithDifferentDimensionsOfFeatures(self): - a = np.array([[1, 2], [3, 4]]) - b = np.array([5, 6]) - x = {'a': a, 'b': b} - y = np.arange(-32, -30) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [[1, 2], [3, 4]]) - self.assertAllEqual(res[0]['b'], [5, 6]) - self.assertAllEqual(res[1], [-32, -31]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithXAsNonDict(self): - x = np.arange(32, 36) - y = np.arange(4) - with self.test_session(): - with self.assertRaisesRegexp(TypeError, 'x must be dict'): - failing_input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - failing_input_fn() - - def testNumpyInputFnWithTargetKeyAlreadyInX(self): - array = np.arange(32, 36) - x = {'__target_key__': array} - y = np.arange(4) - - with self.test_session(): - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - input_fn() - self.assertAllEqual(x['__target_key__'], array) - self.assertItemsEqual(x.keys(), ['__target_key__']) - - def testNumpyInputFnWithMismatchLengthOfInputs(self): - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - x_mismatch_length = {'a': np.arange(1), 'b': b} - y_longer_length = np.arange(10) - - with self.test_session(): - with self.assertRaisesRegexp( - ValueError, 'Length of tensors in x and y is mismatched.'): - failing_input_fn = numpy_io.numpy_input_fn( - x, y_longer_length, batch_size=2, shuffle=False, num_epochs=1) - failing_input_fn() - - with self.assertRaisesRegexp( - ValueError, 'Length of tensors in x and y is mismatched.'): - failing_input_fn = numpy_io.numpy_input_fn( - x=x_mismatch_length, - y=None, - batch_size=2, - shuffle=False, - num_epochs=1) - failing_input_fn() - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index ed6683abedbb8ae76ba364405158eb52cbb6d762..6440bc204b8e339ff51311dcc87b36f556b94092 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -42,10 +42,8 @@ def _args(fn): """ if hasattr(fn, 'func') and hasattr(fn, 'keywords'): # Handle functools.partial and similar objects. - return tuple([ - arg for arg in tf_inspect.getargspec(fn.func).args - if arg not in set(fn.keywords.keys()) - ]) + return tuple( + [arg for arg in _args(fn.func) if arg not in set(fn.keywords.keys())]) # Handle function. return tuple(tf_inspect.getargspec(fn).args) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 6af2287761299f6725f9547917101c18b0cc0164..cb34cb1d26b6812c7f3f39e9f965615de5a8ef07 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -20,7 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.python.client import session as tf_session @@ -78,7 +78,7 @@ def _export_graph(graph, saver, checkpoint_path, export_dir, default_graph_signature=default_graph_signature, named_graph_signatures=named_graph_signatures, assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)) - return export.export(export_dir, contrib_variables.get_global_step(), + return export.export(export_dir, training_util.get_global_step(), session, exports_to_keep=exports_to_keep) @@ -295,7 +295,7 @@ def _export_estimator(estimator, checkpoint_path = (checkpoint_path or tf_saver.latest_checkpoint(estimator._model_dir)) with ops.Graph().as_default() as g: - contrib_variables.create_global_step(g) + training_util.create_global_step(g) if use_deprecated_input_fn: examples = array_ops.placeholder(dtype=dtypes.string, diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 6ffd2a133995a6ff8b35540221fb5676bf5de19f..1593380007b2799fb1d17e92408ab19a7b47fe1e 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -33,7 +33,6 @@ from __future__ import division from __future__ import print_function import os -import tempfile import time from tensorflow.contrib.layers.python.layers import feature_column @@ -51,6 +50,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.summary import summary_iterator from tensorflow.python.training import saver from tensorflow.python.util import compat @@ -391,7 +391,8 @@ def make_export_strategy(serving_input_fn, default_output_alternative_key=None, assets_extra=None, as_text=False, - exports_to_keep=5): + exports_to_keep=5, + strip_default_attrs=None): """Create an ExportStrategy for use with Experiment. Args: @@ -412,12 +413,16 @@ def make_export_strategy(serving_input_fn, exports_to_keep: Number of exports to keep. Older exports will be garbage-collected. Defaults to 5. Set to None to disable garbage collection. + strip_default_attrs: Boolean. If True, default attrs in the + `GraphDef` will be stripped on write. This is recommended for better + forward compatibility of the resulting `SavedModel`. Returns: An ExportStrategy that can be passed to the Experiment constructor. """ - def export_fn(estimator, export_dir_base, checkpoint_path=None): + def export_fn(estimator, export_dir_base, checkpoint_path=None, + strip_default_attrs=False): """Exports the given Estimator as a SavedModel. Args: @@ -426,6 +431,8 @@ def make_export_strategy(serving_input_fn, graph and checkpoints. checkpoint_path: The checkpoint path to export. If None (the default), the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will + be removed from the NodeDefs. Returns: The string path to the exported directory. @@ -444,7 +451,8 @@ def make_export_strategy(serving_input_fn, serving_input_fn, assets_extra=assets_extra, as_text=as_text, - checkpoint_path=checkpoint_path) + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) else: export_result = estimator.export_savedmodel( export_dir_base, @@ -452,12 +460,13 @@ def make_export_strategy(serving_input_fn, default_output_alternative_key=default_output_alternative_key, assets_extra=assets_extra, as_text=as_text, - checkpoint_path=checkpoint_path) + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) garbage_collect_exports(export_dir_base, exports_to_keep) return export_result - return export_strategy.ExportStrategy('Servo', export_fn) + return export_strategy.ExportStrategy('Servo', export_fn, strip_default_attrs) def make_parsing_export_strategy(feature_columns, @@ -465,7 +474,8 @@ def make_parsing_export_strategy(feature_columns, assets_extra=None, as_text=False, exports_to_keep=5, - target_core=False): + target_core=False, + strip_default_attrs=None): """Create an ExportStrategy for use with Experiment, using `FeatureColumn`s. Creates a SavedModel export that expects to be fed with a single string @@ -493,6 +503,9 @@ def make_parsing_export_strategy(feature_columns, target_core: If True, prepare an ExportStrategy for use with tensorflow.python.estimator.*. If False (default), prepare an ExportStrategy for use with tensorflow.contrib.learn.python.learn.*. + strip_default_attrs: Boolean. If True, default attrs in the + `GraphDef` will be stripped on write. This is recommended for better + forward compatibility of the resulting `SavedModel`. Returns: An ExportStrategy that can be passed to the Experiment constructor. @@ -509,7 +522,8 @@ def make_parsing_export_strategy(feature_columns, default_output_alternative_key=default_output_alternative_key, assets_extra=assets_extra, as_text=as_text, - exports_to_keep=exports_to_keep) + exports_to_keep=exports_to_keep, + strip_default_attrs=strip_default_attrs) def _default_compare_fn(curr_best_eval_result, cand_eval_result): @@ -543,15 +557,16 @@ def _default_compare_fn(curr_best_eval_result, cand_eval_result): class BestModelSelector(object): """A helper that keeps track of export selection candidates.""" - def __init__(self, compare_fn=None): + def __init__(self, event_file_pattern=None, compare_fn=None): """Constructor of this class. Args: + event_file_pattern: absolute event file name pattern. compare_fn: a function that returns true if the candidate is better than the current best model. """ - self._best_eval_result = None self._compare_fn = compare_fn or _default_compare_fn + self._best_eval_result = self._get_best_eval_result(event_file_pattern) def update(self, checkpoint_path, eval_result): """Records a given checkpoint and exports if this is the best model. @@ -581,11 +596,40 @@ class BestModelSelector(object): else: return '', None + def _get_best_eval_result(self, event_files): + """Get the best eval result from event files. -def make_best_model_export_strategy(serving_input_fn, - exports_to_keep=1, - compare_fn=None, - default_output_alternative_key=None): + Args: + event_files: Absolute pattern of event files. + + Returns: + The best eval result. + """ + if not event_files: + return None + + best_eval_result = None + for event_file in gfile.Glob(os.path.join(event_files)): + for event in summary_iterator.summary_iterator(event_file): + if event.HasField('summary'): + event_eval_result = {} + for value in event.summary.value: + if value.HasField('simple_value'): + event_eval_result[value.tag] = value.simple_value + if best_eval_result is None or self._compare_fn( + best_eval_result, event_eval_result): + best_eval_result = event_eval_result + return best_eval_result + + +def make_best_model_export_strategy( + serving_input_fn, + exports_to_keep=1, + model_dir=None, + event_file_pattern=None, + compare_fn=None, + default_output_alternative_key=None, + strip_default_attrs=None): """Creates an custom ExportStrategy for use with tf.contrib.learn.Experiment. Args: @@ -593,10 +637,24 @@ def make_best_model_export_strategy(serving_input_fn, `InputFnOps`. exports_to_keep: an integer indicating how many historical best models need to be preserved. + model_dir: Directory where model parameters, graph etc. are saved. This will + be used to load eval metrics from the directory when the export strategy + is created. So the best metrics would not be lost even if the export + strategy got preempted, which guarantees that only the best model would + be exported regardless of preemption. If None, however, the export + strategy would not be preemption-safe. To be preemption-safe, both + model_dir and event_file_pattern would be needed. + event_file_pattern: event file name pattern relative to model_dir, e.g. + "eval_continuous/*.tfevents.*". If None, however, the export strategy + would not be preemption-safe. To be preemption-safe, both + model_dir and event_file_pattern would be needed. compare_fn: a function that select the 'best' candidate from a dictionary of evaluation result keyed by corresponding checkpoint path. default_output_alternative_key: the key for default serving signature for multi-headed inference graphs. + strip_default_attrs: Boolean. If True, default attrs in the + `GraphDef` will be stripped on write. This is recommended for better + forward compatibility of the resulting `SavedModel`. Returns: An ExportStrategy that can be passed to the Experiment constructor. @@ -604,9 +662,13 @@ def make_best_model_export_strategy(serving_input_fn, best_model_export_strategy = make_export_strategy( serving_input_fn, exports_to_keep=exports_to_keep, - default_output_alternative_key=default_output_alternative_key) + default_output_alternative_key=default_output_alternative_key, + strip_default_attrs=strip_default_attrs) - best_model_selector = BestModelSelector(compare_fn) + full_event_file_pattern = os.path.join( + model_dir, + event_file_pattern) if model_dir and event_file_pattern else None + best_model_selector = BestModelSelector(full_event_file_pattern, compare_fn) def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None): """Exports the given Estimator as a SavedModel. @@ -682,22 +744,36 @@ def extend_export_strategy(base_export_strategy, ValueError: If `estimator` is a ${tf.estimator.Estimator} instance and `default_output_alternative_key` was specified or if post_export_fn does not return a valid directory. + RuntimeError: If unable to create temporary or final export directory. """ - tmp_base_export_dir = tempfile.mkdtemp() + tmp_base_export_folder = 'temp-base-export-' + str(int(time.time())) + tmp_base_export_dir = os.path.join(export_dir_base, tmp_base_export_folder) + if gfile.Exists(tmp_base_export_dir): + raise RuntimeError('Failed to obtain base export directory') + gfile.MakeDirs(tmp_base_export_dir) tmp_base_export = base_export_strategy.export( estimator, tmp_base_export_dir, checkpoint_path) - tmp_post_export_dir = tempfile.mkdtemp() + + tmp_post_export_folder = 'temp-post-export-' + str(int(time.time())) + tmp_post_export_dir = os.path.join(export_dir_base, tmp_post_export_folder) + if gfile.Exists(tmp_post_export_dir): + raise RuntimeError('Failed to obtain temp export directory') + + gfile.MakeDirs(tmp_post_export_dir) tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir) if not tmp_post_export.startswith(tmp_post_export_dir): raise ValueError('post_export_fn must return a sub-directory of {}' .format(tmp_post_export_dir)) - export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) - - gfile.Rename( - os.path.join(tmp_post_export_dir, export_relpath), - os.path.join(export_dir_base, export_relpath)) - return os.path.join(export_dir_base, export_relpath) + post_export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) + post_export = os.path.join(export_dir_base, post_export_relpath) + if gfile.Exists(post_export): + raise RuntimeError('Failed to obtain final export directory') + gfile.Rename(tmp_post_export, post_export) + + gfile.DeleteRecursively(tmp_base_export_dir) + gfile.DeleteRecursively(tmp_post_export_dir) + return post_export name = post_export_name if post_export_name else base_export_strategy.name return export_strategy.ExportStrategy(name, export_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index ec3a88003f01b3b62591c13472029601b11ba491..14bf1136e8e9ab1488c4850d458382028ec5583d 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -24,13 +24,14 @@ import time from tensorflow.contrib.layers.python.layers import feature_column as fc from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib from tensorflow.contrib.learn.python.learn.estimators import constants -from tensorflow.contrib.learn.python.learn.estimators import estimator as core_estimator +from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.utils import input_fn_utils from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -41,7 +42,7 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.util import compat -class TestEstimator(core_estimator.Estimator): +class TestEstimator(estimator.Estimator): def __init__(self, *args, **kwargs): super(TestEstimator, self).__init__(*args, **kwargs) @@ -55,7 +56,8 @@ class TestEstimator(core_estimator.Estimator): default_output_alternative_key=None, assets_extra=None, as_text=False, - checkpoint_path=None): + checkpoint_path=None, + strip_default_attrs=False): if not os.path.exists(export_dir): os.makedirs(export_dir) @@ -93,9 +95,9 @@ class SavedModelExportUtilsTest(test.TestCase): name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.REGRESS_OUTPUTS].CopyFrom( - meta_graph_pb2.TensorInfo(name="output-tensor-1:0", - dtype=dtype_float, - tensor_shape=shape)) + meta_graph_pb2.TensorInfo( + name="output-tensor-1:0", dtype=dtype_float, + tensor_shape=shape)) expected_signature_def.method_name = signature_constants.REGRESS_METHOD_NAME self.assertEqual(actual_signature_def, expected_signature_def) @@ -506,7 +508,9 @@ class SavedModelExportUtilsTest(test.TestCase): input_example = constant_op.constant(["input string"]) input_ops = input_fn_utils.InputFnOps({ "features": input_features - }, None, {"default input": input_example}) + }, None, { + "default input": input_example + }) input_alternatives, _ = ( saved_model_export_utils.get_input_alternatives(input_ops)) output_1 = constant_op.constant([1.0]) @@ -527,8 +531,9 @@ class SavedModelExportUtilsTest(test.TestCase): model_fn.ModeKeys.INFER, predictions={"some_output": constant_op.constant(["4"])}, output_alternatives=provided_output_alternatives) - output_alternatives, _ = (saved_model_export_utils.get_output_alternatives( - model_fn_ops, "head-1")) + output_alternatives, _ = ( + saved_model_export_utils.get_output_alternatives( + model_fn_ops, "head-1")) signature_defs = saved_model_export_utils.build_all_signature_defs( input_alternatives, output_alternatives, "head-1") @@ -546,7 +551,9 @@ class SavedModelExportUtilsTest(test.TestCase): "default_input_alternative:head-3": signature_def_utils.predict_signature_def({ "default input": input_example - }, {"some_output_3": output_3}), + }, { + "some_output_3": output_3 + }), # "features_input_alternative:head-1": # signature_def_utils.regression_signature_def(input_features, # output_1), @@ -589,8 +596,9 @@ class SavedModelExportUtilsTest(test.TestCase): model_fn.ModeKeys.INFER, predictions={"some_output": constant_op.constant(["4"])}, output_alternatives=provided_output_alternatives) - output_alternatives, _ = (saved_model_export_utils.get_output_alternatives( - model_fn_ops, "head-1")) + output_alternatives, _ = ( + saved_model_export_utils.get_output_alternatives( + model_fn_ops, "head-1")) with self.assertRaisesRegexp( ValueError, "A default input_alternative must be provided"): @@ -706,25 +714,72 @@ class SavedModelExportUtilsTest(test.TestCase): self.assertNotEqual("", export_strategy.export(test_estimator, export_dir_base, - "fake_ckpt_0", {"loss": 100})) + "fake_ckpt_0", { + "loss": 100 + })) self.assertNotEqual("", test_estimator.last_exported_dir) self.assertNotEqual("", test_estimator.last_exported_checkpoint) self.assertEqual("", export_strategy.export(test_estimator, export_dir_base, - "fake_ckpt_1", {"loss": 101})) + "fake_ckpt_1", { + "loss": 101 + })) self.assertEqual(test_estimator.last_exported_dir, os.path.join(export_dir_base, "fake_ckpt_0")) self.assertNotEqual("", export_strategy.export(test_estimator, export_dir_base, - "fake_ckpt_2", {"loss": 10})) + "fake_ckpt_2", { + "loss": 10 + })) + self.assertEqual(test_estimator.last_exported_dir, + os.path.join(export_dir_base, "fake_ckpt_2")) + + self.assertEqual("", + export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_3", { + "loss": 20 + })) + self.assertEqual(test_estimator.last_exported_dir, + os.path.join(export_dir_base, "fake_ckpt_2")) + + def test_make_best_model_export_strategy_with_preemption(self): + model_dir = self.get_temp_dir() + eval_dir_base = os.path.join(model_dir, "eval_continuous") + core_estimator._write_dict_to_summary(eval_dir_base, {"loss": 50}, 1) + core_estimator._write_dict_to_summary(eval_dir_base, {"loss": 60}, 2) + + test_estimator = TestEstimator() + export_strategy = saved_model_export_utils.make_best_model_export_strategy( + serving_input_fn=None, + exports_to_keep=3, + model_dir=model_dir, + event_file_pattern="eval_continuous/*.tfevents.*", + compare_fn=None) + + export_dir_base = os.path.join(self.get_temp_dir(), "export") + self.assertEqual("", + export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_0", { + "loss": 100 + })) + self.assertEqual("", test_estimator.last_exported_dir) + self.assertEqual("", test_estimator.last_exported_checkpoint) + + self.assertNotEqual("", + export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_2", { + "loss": 10 + })) self.assertEqual(test_estimator.last_exported_dir, os.path.join(export_dir_base, "fake_ckpt_2")) self.assertEqual("", export_strategy.export(test_estimator, export_dir_base, - "fake_ckpt_3", {"loss": 20})) + "fake_ckpt_3", { + "loss": 20 + })) self.assertEqual(test_estimator.last_exported_dir, os.path.join(export_dir_base, "fake_ckpt_2")) @@ -766,10 +821,11 @@ class SavedModelExportUtilsTest(test.TestCase): test_estimator = TestEstimator() tmpdir = tempfile.mkdtemp() - final_path = final_export_strategy.export(test_estimator, tmpdir, - os.path.join( - tmpdir, "checkpoint")) - self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + export_model_dir = os.path.join(tmpdir, "model") + checkpoint_path = os.path.join(tmpdir, "checkpoint") + final_path = final_export_strategy.export(test_estimator, export_model_dir, + checkpoint_path) + self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path) def test_extend_export_strategy_same_name(self): @@ -795,10 +851,11 @@ class SavedModelExportUtilsTest(test.TestCase): test_estimator = TestEstimator() tmpdir = tempfile.mkdtemp() - final_path = final_export_strategy.export(test_estimator, tmpdir, - os.path.join( - tmpdir, "checkpoint")) - self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + export_model_dir = os.path.join(tmpdir, "model") + checkpoint_path = os.path.join(tmpdir, "checkpoint") + final_path = final_export_strategy.export(test_estimator, export_model_dir, + checkpoint_path) + self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path) def test_extend_export_strategy_raises_error(self): diff --git a/tensorflow/contrib/legacy_seq2seq/python/__init__.py b/tensorflow/contrib/legacy_seq2seq/python/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..52e83069cb0c68b510da46149248369dce376647 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/__init__.py +++ b/tensorflow/contrib/legacy_seq2seq/python/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..52e83069cb0c68b510da46149248369dce376647 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py +++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/libsvm/BUILD b/tensorflow/contrib/libsvm/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..df96402a4ffd51840f77d58d8066487030362340 --- /dev/null +++ b/tensorflow/contrib/libsvm/BUILD @@ -0,0 +1,102 @@ +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +tf_custom_op_library( + name = "python/ops/_libsvm_ops.so", + srcs = [ + "kernels/decode_libsvm_op.cc", + "ops/libsvm_ops.cc", + ], + deps = [ + "//tensorflow/core/kernels:bounds_check_lib", + ], +) + +tf_kernel_library( + name = "libsvm_kernels", + srcs = ["kernels/decode_libsvm_op.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:bounds_check_lib", + ], +) + +tf_gen_op_libs( + op_lib_names = ["libsvm_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "libsvm_ops", + deps = [":libsvm_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "libsvm", + srcs = [ + "__init__.py", + "python/ops/libsvm_ops.py", + ], + dso = [ + ":python/ops/_libsvm_ops.so", + ], + kernels = [ + ":libsvm_kernels", + ":libsvm_ops_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":libsvm_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + ], +) + +tf_py_test( + name = "decode_libsvm_op_test", + srcs = ["python/kernel_tests/decode_libsvm_op_test.py"], + additional_deps = [ + ":libsvm", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered_impl.py b/tensorflow/contrib/libsvm/__init__.py similarity index 58% rename from tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered_impl.py rename to tensorflow/contrib/libsvm/__init__.py index 223bc9d042c69be05b0e578835a31ed6e83c0c97..a875863caab29eb59a1834ca9184a5e272cb6656 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered_impl.py +++ b/tensorflow/contrib/libsvm/__init__.py @@ -12,28 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""SigmoidCentered bijector.""" +"""Libsvm decoder. + +@@decode_libsvm +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered +from tensorflow.contrib.libsvm.python.ops.libsvm_ops import decode_libsvm +from tensorflow.python.util.all_util import remove_undocumented -__all__ = [ - "SigmoidCentered", +_allowed_symbols = [ + "decode_libsvm", ] - -class SigmoidCentered(softmax_centered.SoftmaxCentered): - """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)). - - Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. - - See `bijector.SoftmaxCentered` for more details. - """ - - def __init__(self, validate_args=False, name="sigmoid_centered"): - super(SigmoidCentered, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) +remove_undocumented(__name__) diff --git a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..720c74e3de5907fa006227d1278c45fd2175fe5f --- /dev/null +++ b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc @@ -0,0 +1,168 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +template +class DecodeLibsvmOp : public OpKernel { + public: + explicit DecodeLibsvmOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_features", &num_features_)); + OP_REQUIRES(ctx, (num_features_ >= 1), + errors::InvalidArgument("Invalid number of features \"", + num_features_, "\"")); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat(); + + Tensor* label_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, input_tensor->shape(), &label_tensor)); + auto label = label_tensor->flat(); + + std::vector out_values; + std::vector> out_indices; + for (int i = 0; i < input_flat.size(); ++i) { + StringPiece line(input_flat(i)); + str_util::RemoveWhitespaceContext(&line); + + StringPiece piece; + OP_REQUIRES(ctx, str_util::ConsumeNonWhitespace(&line, &piece), + errors::InvalidArgument("No label found for input[", i, + "]: \"", input_flat(i), "\"")); + + Tlabel label_value; + OP_REQUIRES(ctx, + strings::SafeStringToNumeric(piece, &label_value), + errors::InvalidArgument("Label format incorrect: ", piece)); + + label(i) = label_value; + + str_util::RemoveLeadingWhitespace(&line); + while (str_util::ConsumeNonWhitespace(&line, &piece)) { + size_t p = piece.find(':'); + OP_REQUIRES(ctx, (p != StringPiece::npos), + errors::InvalidArgument("Invalid feature \"", piece, "\"")); + + int64 feature_index; + OP_REQUIRES( + ctx, strings::safe_strto64(piece.substr(0, p), &feature_index), + errors::InvalidArgument("Feature format incorrect: ", piece)); + OP_REQUIRES(ctx, (feature_index >= 0), + errors::InvalidArgument( + "Feature index should be >= 0, got ", feature_index)); + + T feature_value; + OP_REQUIRES( + + ctx, + strings::SafeStringToNumeric(piece.substr(p + 1), + &feature_value), + errors::InvalidArgument("Feature format incorrect: ", piece)); + + out_values.emplace_back(feature_value); + out_indices.emplace_back(std::pair(i, feature_index)); + + str_util::RemoveLeadingWhitespace(&line); + } + } + + Tensor* indices_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 1, + TensorShape({static_cast(out_indices.size()), + input_tensor->shape().dims() + 1}), + &indices_tensor)); + auto indices = indices_tensor->matrix(); + // Translate flat index to shaped index like np.unravel_index + // Calculate factors for each dimension + std::vector factors(input_tensor->shape().dims()); + factors[input_tensor->shape().dims() - 1] = 1; + for (int j = input_tensor->shape().dims() - 2; j >= 0; j--) { + factors[j] = factors[j + 1] * input_tensor->shape().dim_size(j + 1); + } + for (int i = 0; i < out_indices.size(); i++) { + indices(i, 0) = out_indices[i].first; + int64 value = out_indices[i].first; + for (int j = 0; j < input_tensor->shape().dims(); j++) { + indices(i, j) = value / factors[j]; + value = value % factors[j]; + } + indices(i, input_tensor->shape().dims()) = out_indices[i].second; + } + + Tensor* values_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_output( + 2, TensorShape({static_cast(out_values.size())}), + &values_tensor)); + auto values = values_tensor->vec(); + std::copy_n(out_values.begin(), out_values.size(), &values(0)); + + Tensor* shape_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 3, TensorShape({input_tensor->shape().dims() + 1}), + &shape_tensor)); + auto shape = shape_tensor->flat(); + for (int i = 0; i < input_tensor->shape().dims(); i++) { + shape(i) = input_tensor->shape().dim_size(i); + } + shape(input_tensor->shape().dims()) = num_features_; + } + + private: + int64 num_features_; +}; + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); + +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(int64); +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/contrib/libsvm/ops/libsvm_ops.cc b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..dec946189e3cd67e2557b83806c0db79a46e5f82 --- /dev/null +++ b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; + +REGISTER_OP("DecodeLibsvm") + .Input("input: string") + .Output("label: label_dtype") + .Output("feature_indices: int64") + .Output("feature_values: dtype") + .Output("feature_shape: int64") + .Attr("dtype: {float, double, int32, int64} = DT_FLOAT") + .Attr("label_dtype: {float, double, int32, int64} = DT_INT64") + .Attr("num_features: int >= 1") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + + c->set_output(1, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); + c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(3, c->Vector(InferenceContext::kUnknownDim)); + + return Status::OK(); + }) + + .Doc(R"doc( +Convert LibSVM input to tensors. The output consists of +a label and a feature tensor. The shape of the label tensor +is the same as input and the shape of the feature tensor is +`[input_shape, num_features]`. + +input: Each string is a record in the LibSVM. +label: A tensor of the same shape as input. +feature_indices: A 2-D int64 tensor of dense_shape [N, ndims]. +feature_values: A 1-D tensor of any type and dense_shape [N]. +feature_shape: A 1-D int64 tensor of dense_shape [ndims]. +num_features: The number of features. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..423dcce8de9b9c77fcfdc8c90c909e2918852905 --- /dev/null +++ b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py @@ -0,0 +1,71 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DecodeLibsvm op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.libsvm.python.ops import libsvm_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class DecodeLibsvmOpTest(test.TestCase): + + def testBasic(self): + with self.test_session() as sess: + content = [ + "1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503", + "2 3:2.5 2:nan 1:0.105" + ] + sparse_features, labels = libsvm_ops.decode_libsvm( + content, num_features=6) + features = sparse_ops.sparse_tensor_to_dense( + sparse_features, validate_indices=False) + + self.assertAllEqual(labels.get_shape().as_list(), [3]) + + features, labels = sess.run([features, labels]) + self.assertAllEqual(labels, [1, 1, 2]) + self.assertAllClose( + features, [[0, 3.4, 0.5, 0, 0.231, 0], [0, 0, 2.5, np.inf, 0, 0.503], + [0, 0.105, np.nan, 2.5, 0, 0]]) + + def testNDimension(self): + with self.test_session() as sess: + content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"], + ["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"], + ["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]] + sparse_features, labels = libsvm_ops.decode_libsvm( + content, num_features=6, label_dtype=dtypes.float64) + features = sparse_ops.sparse_tensor_to_dense( + sparse_features, validate_indices=False) + + self.assertAllEqual(labels.get_shape().as_list(), [3, 2]) + + features, labels = sess.run([features, labels]) + self.assertAllEqual(labels, [[1, 1], [1, 1], [2, 2]]) + self.assertAllClose( + features, [[[0, 3.4, 0.5, 0, 0.231, 0], [0, 3.4, 0.5, 0, 0.231, 0]], [ + [0, 0, 2.5, np.inf, 0, 0.503], [0, 0, 2.5, np.inf, 0, 0.503] + ], [[0, 0.105, np.nan, 2.5, 0, 0], [0, 0.105, np.nan, 2.5, 0, 0]]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b3022505635bca81625cf7abd2be5628a4760970 --- /dev/null +++ b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Libsvm decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.libsvm.ops import gen_libsvm_ops +from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import resource_loader + + +_libsvm_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_libsvm_ops.so")) + + +def decode_libsvm(content, num_features, dtype=None, label_dtype=None): + """Convert Libsvm records to a tensor of label and a tensor of feature. + + Args: + content: A `Tensor` of type `string`. Each string is a record/row in + the Libsvm format. + num_features: The number of features. + dtype: The type of the output feature tensor. Default to tf.float32. + label_dtype: The type of the output label tensor. Default to tf.int64. + + Returns: + features: A `SparseTensor` of the shape `[input_shape, num_features]`. + labels: A `Tensor` of the same shape as content. + """ + labels, indices, values, shape = gen_libsvm_ops.decode_libsvm( + content, num_features, dtype=dtype, label_dtype=label_dtype) + return sparse_tensor.SparseTensor(indices, values, shape), labels + + +ops.NotDifferentiable("DecodeLibSVM") diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index 7e214905b13db6a7e2f54f15873f5a9aedb4f44f..ec726bbed41a86eb314e3591ecaedaa6bf0e5e9b 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -102,7 +102,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): keys.get_shape()) def lookup(self, keys, name=None): - if keys.dtype != self._key_dtype: + if keys.dtype.base_dtype != self._key_dtype: raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' % (self._key_dtype, keys.dtype)) self._check_keys(keys) diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index 701fc1c0597d1de0b0189e86feafbd1c5bbdc818..05794a42c5f2d0eece6adab36fb5610078cece31 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key @@ -154,7 +154,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step( columns_to_variables, weight_column_name, loss_type, features, labels, global_step) diff --git a/tensorflow/contrib/lite/Android.bp b/tensorflow/contrib/lite/Android.bp index be4fa7c390161beddadaa2bcf34b0cdff73b6511..2b91f1e8c900ab8ab1d99cb803944821aa038d84 100644 --- a/tensorflow/contrib/lite/Android.bp +++ b/tensorflow/contrib/lite/Android.bp @@ -37,6 +37,7 @@ cc_library_static { rtti: true, srcs: [ "allocation.cc", + "arena_planner.cc", "error_reporter.cc", "interpreter.cc", "model.cc", @@ -51,7 +52,9 @@ cc_library_static { "gemmlowp_headers", ], cflags: [ + "-Wno-mismatched-tags", "-Wno-sign-compare", + "-Wno-unused-lambda-capture", ], } @@ -73,4 +76,4 @@ build = [ "tflite_static.bp", ] -subdirs = ["kernels"] \ No newline at end of file +subdirs = ["kernels"] diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 52460123cc10ec9b2ee13043fd43f84508b05000..13350c5a438b75fe14e8753e5bb1bb77ec8f655b 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -35,6 +35,28 @@ cc_library( hdrs = ["version.h"], ) +cc_library( + name = "arena_planner", + srcs = ["arena_planner.cc"], + hdrs = ["arena_planner.h"], + deps = [ + ":context", + ":graph_info", + ":memory_planner", + ":simple_memory_arena", + ], +) + +cc_test( + name = "arena_planner_test", + size = "small", + srcs = ["arena_planner_test.cc"], + deps = [ + ":arena_planner", + "@com_google_googletest//:gtest", + ], +) + # Main library. No ops are included here. # TODO(aselle): Resolve problems preventing C99 usage. cc_library( @@ -43,6 +65,25 @@ cc_library( hdrs = ["context.h"], ) +cc_library( + name = "graph_info", + hdrs = ["graph_info.h"], + deps = [":context"], +) + +cc_library( + name = "memory_planner", + hdrs = ["memory_planner.h"], + deps = [":context"], +) + +cc_library( + name = "simple_memory_arena", + srcs = ["simple_memory_arena.cc"], + hdrs = ["simple_memory_arena.h"], + deps = [":context"], +) + cc_library( name = "builtin_op_data", hdrs = [ @@ -70,7 +111,6 @@ cc_library( "model.cc", "nnapi_delegate.cc", "optional_debug_tools.cc", - "simple_memory_arena.cc", ], hdrs = [ "allocation.h", @@ -80,13 +120,16 @@ cc_library( "model.h", "nnapi_delegate.h", "optional_debug_tools.h", - "simple_memory_arena.h", ], copts = tflite_copts(), deps = [ + ":arena_planner", ":builtin_op_data", ":context", + ":graph_info", + ":memory_planner", ":schema_fbs_version", + ":simple_memory_arena", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", "//tensorflow/contrib/lite/schema:schema_fbs", @@ -111,6 +154,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -133,7 +177,8 @@ cc_test( size = "small", srcs = ["simple_memory_arena_test.cc"], deps = [ - ":framework", + ":simple_memory_arena", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -152,6 +197,7 @@ cc_test( ], deps = [ ":framework", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -163,6 +209,7 @@ cc_test( srcs = ["context_test.cc"], deps = [ ":framework", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index 78402727abdd2742ffff54bf59ca076d8b97b042..7f316292724ea0baaf034d4e914773ad97a957d4 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -56,7 +56,7 @@ LIBS := \ -lz # If we're on Linux, also link in the dl library. -ifeq ($(OS),LINUX) +ifeq ($(HOST_OS),LINUX) LIBS += -ldl -lpthread endif diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index c7464bcc9d39b0e884e76f5a3ffa152e98bb0f47..55a524b207b258e794f97e68a96cf01dc60efb7f 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -4,7 +4,7 @@ TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded dev TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. ![image](g3doc/TFLite-Architecture.jpg) -# Getting Started with a Demo App +# Getting Started with an Android Demo App This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo. @@ -17,7 +17,7 @@ There are 3 ways to get the demo app to your device In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. ## Downloading the pre-built binary -The fastest path to trying the demo, is to download the pre-built binary +The fastest path to trying the demo, is to download the pre-built binary [TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. @@ -69,7 +69,7 @@ android_ndk_repository( Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). -### Build the source code +### Build the source code Run bazel with the following command to build the demo. Build the demo app: @@ -86,6 +86,17 @@ environment (due to a Bazel bug). ### More about the demo The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app. +# iOS Demo App + +Similar to the Android demo app, there's an iOS camera app that uses exactly the same model (224 * 224 quantized Mobilenet). + +This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: + +1. Follow the Building section [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md#building) to build the universal iOS library for TensorFlow Lite. +1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. +1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. +1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. + # TensorFlow Lite Quick Start ## Step 1. Decide which GraphDef to use @@ -156,6 +167,7 @@ graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/te This frozen Graphdef is now ready to be converted to flatbuffer format (.lite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. Here is a sample command line to convert the frozen Graphdef to '.lite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. +(Here is a link to the pb [file](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)). ``` bazel build tensorflow/contrib/lite/toco:toco @@ -174,9 +186,9 @@ bazel-bin/tensorflow/contrib/lite/toco/toco -- \ - Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the -documentation [here](https://github.com/tensorflow/tensorflow/tree/mastertensorflow/contrib/lite/python:toco_from_protos target) A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, +documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/python/toco_from_protos.py). A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, -``` +```python import tensorflow as tf img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) @@ -191,6 +203,12 @@ For detailed instructions on how to use the Tensorflow Optimizing Converter, ple You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn't help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). +If you would like to see a visual description of your TensorFlow Lite model after conversion, you can use tensorflow/contrib/lite/tools/visualize.py by running +```sh +bazel run tensorflow/contrib/lite/tools:visualize -- model.tflite model_viz.html +``` +and then visualize the resulting HTML file in a browser. + ## Step 3. Use the TensorFlow Lite model for inference in a mobile app After completion of Step 2 the developer should have a .lite model. @@ -204,3 +222,7 @@ Note that you'd need to follow instructions for installing TensorFlow on Android ### For iOS Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. + +## Core ML support + +Core ML is a machine learning framework used across Apple products. In addition to using Tensorflow Lite models directly in their applications, developers have the option to convert their trained Tensorflow models to the [CoreML](https://developer.apple.com/machine-learning/) format for use on Apple devices. For information on how to use the converter please refer to the [Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml). diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf1bcdd1a7a7d3395c45ae95abd5980e9ffc0fc6 --- /dev/null +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -0,0 +1,247 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/arena_planner.h" + +namespace tflite { + +namespace { + +// Memory allocation tuning +constexpr const int kDefaultArenaAlignment = 64; +constexpr const int kDefaultTensorAlignment = 4; + +} // namespace + +struct AllocationInfo { + // The node index requesting this allocation. + int node; + // The tensor index to be allocated or deallocated. + int tensor; + // Whether to allocate or deallocate + enum { ALLOC, DEALLOC } type; +}; + +ArenaPlanner::ArenaPlanner(TfLiteContext* context, + std::unique_ptr graph_info) + : context_(context), + graph_info_(std::move(graph_info)), + arena_(kDefaultArenaAlignment), + persistent_arena_(kDefaultArenaAlignment) {} + +ArenaPlanner::~ArenaPlanner() {} + +int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) { + if (type == kTfLiteArenaRwPersistent) { + return persistent_arena_.BasePointer(); + } + if (type == kTfLiteArenaRw) { + return arena_.BasePointer(); + } + return 0; +} + +TfLiteStatus ArenaPlanner::ResetAllocations() { + TF_LITE_ENSURE_STATUS(arena_.Clear()); + TF_LITE_ENSURE_STATUS(persistent_arena_.Clear()); + allocs_.clear(); + allocs_.resize(graph_info_->num_tensors()); + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::PlanAllocations() { + // Invalidate any existing data. + TF_LITE_ENSURE_STATUS(ResetAllocations()); + + // Keeps track of references to each tensor. + std::vector refcounts(graph_info_->num_tensors(), 0); + + // There will be an entry in alloc_queue_ for the allocation of each tensor + // and another for their deallocation. + alloc_queue_.reserve(2 * graph_info_->num_tensors()); + + // We must make sure the output tensors are never overwritten. We do that by + // artificially adding one to their ref-counts so they are never selected + // for deallocation. + for (int tensor_index : graph_info_->outputs()) { + refcounts[tensor_index]++; + } + + // Count references to node input tensors. + for (int i = 0; i < graph_info_->num_nodes(); ++i) { + const TfLiteNode& node = graph_info_->node(i); + TfLiteIntArray* node_inputs = node.inputs; + for (int j = 0; j < node_inputs->size; ++j) { + int tensor_index = node_inputs->data[j]; + if (tensor_index != kOptionalTensor) { + refcounts[tensor_index]++; + } + } + } + + // Queue all graph inputs for allocation. + for (int tensor_index : graph_info_->inputs()) { + if (tensor_index != kOptionalTensor) { + alloc_queue_.push_back({0, tensor_index, AllocationInfo::ALLOC}); + } + } + + // Go through the graph in execution order. + for (int i = 0; i < graph_info_->num_nodes(); ++i) { + const TfLiteNode& node = graph_info_->node(i); + + // First queue output tensors for allocation. + TfLiteIntArray* node_outputs = node.outputs; + for (int j = 0; j < node_outputs->size; ++j) { + int tensor_index = node_outputs->data[j]; + alloc_queue_.push_back({i, tensor_index, AllocationInfo::ALLOC}); + } + + // Then update the ref-counts of the node's inputs, and if necessary queue + // them for deallocation. + TfLiteIntArray* node_inputs = node.inputs; + for (int j = 0; j < node_inputs->size; ++j) { + int tensor_index = node_inputs->data[j]; + if (tensor_index != kOptionalTensor) { + refcounts[tensor_index]--; + if (refcounts[tensor_index] == 0) { + alloc_queue_.push_back({i, tensor_index, AllocationInfo::DEALLOC}); + } + } + } + } + + // Note that graph outputs will never be scheduled for deallocation. We + // could do that here for completeness, but it won't have any effect. + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) { + TF_LITE_ENSURE_STATUS(CalculateAllocations(first_node, last_node)); + TF_LITE_ENSURE_STATUS(Commit()); + + for (int i = 0; i < graph_info_->num_tensors(); ++i) { + // TODO(ahentz): we could do this only for the tensors that were modified + // in CalculateAllocations(), instead of redoing it for tensors that + // already had proper pointers. However we must be very careful, because + // SimpleMemoryArena::Commit() could move the base pointer. + TF_LITE_ENSURE_STATUS(ResolveTensorAllocation(i)); + } + + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::Commit() { + TF_LITE_ENSURE_STATUS(arena_.Commit(context_)); + TF_LITE_ENSURE_STATUS(persistent_arena_.Commit(context_)); + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateAllocations(int first_node, int last_node) { + int active_node = first_node; + // When dynamic tensors are present this method is called multiple times. + // The items in the alloc_queue_ referring to nodes before first_node were + // processed previously and should be skipped. Entries after last_node are + // not yet ready to be handled. + for (const auto& alloc_info : alloc_queue_) { + if (alloc_info.node < first_node) continue; + if (alloc_info.node > last_node) break; + if (alloc_info.node == active_node) { + // This is the first allocation/deallocation for a given node. It is + // time to deallocate the previous temporaries and allocate new ones. + if (active_node != first_node) { + TF_LITE_ENSURE_STATUS( + CalculateDeallocationOfInternalTensors(active_node - 1)); + } + TF_LITE_ENSURE_STATUS(CalculateAllocationOfInternalTensors(active_node)); + ++active_node; + } + // Handle the current item. + if (alloc_info.type == AllocationInfo::ALLOC) { + TF_LITE_ENSURE_STATUS(CalculateTensorAllocation(alloc_info.tensor)); + } else { + TF_LITE_ENSURE_STATUS(CalculateTensorDeallocation(alloc_info.tensor)); + } + } + + // Don't forget to deallocate temporaries of last node. + TF_LITE_ENSURE_STATUS( + CalculateDeallocationOfInternalTensors(active_node - 1)); + + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) { + TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_STATUS( + arena_.ResolveAlloc(context_, allocs_[tensor_index], &tensor.data.raw)); + } + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_STATUS(persistent_arena_.ResolveAlloc( + context_, allocs_[tensor_index], &tensor.data.raw)); + } + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateTensorAllocation(int tensor_index) { + TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_STATUS(arena_.Allocate(context_, kDefaultTensorAlignment, + tensor.bytes, + &allocs_[tensor_index])); + } + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_STATUS( + persistent_arena_.Allocate(context_, kDefaultTensorAlignment, + tensor.bytes, &allocs_[tensor_index])); + } + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateTensorDeallocation(int tensor_index) { + TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_STATUS(arena_.Deallocate(context_, allocs_[tensor_index])); + } + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateAllocationOfInternalTensors( + int node_index) { + if (node_index < graph_info_->num_nodes()) { + const TfLiteNode& node = graph_info_->node(node_index); + TfLiteIntArray* node_temporaries = node.temporaries; + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TF_LITE_ENSURE_STATUS(CalculateTensorAllocation(tensor_index)); + } + } + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateDeallocationOfInternalTensors( + int node_index) { + if (node_index < graph_info_->num_nodes()) { + const TfLiteNode& node = graph_info_->node(node_index); + TfLiteIntArray* node_temporaries = node.temporaries; + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TF_LITE_ENSURE_STATUS(CalculateTensorDeallocation(tensor_index)); + } + } + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h new file mode 100644 index 0000000000000000000000000000000000000000..bd87414ec3c8ac75b99e730fcac977a7afa08806 --- /dev/null +++ b/tensorflow/contrib/lite/arena_planner.h @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ + +#include +#include + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/graph_info.h" +#include "tensorflow/contrib/lite/memory_planner.h" +#include "tensorflow/contrib/lite/simple_memory_arena.h" + +namespace tflite { + +class AllocationInfo; + +// A memory planner that makes all the allocations using arenas. +// +// Before a model is executed by the interpreter, this class determines when +// each tensor needs to be allocated and deallocated, and preallocates all the +// necessary memory (the PlanAllocations phase). It then assigns portions of +// this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may +// share some of the bufer if a tensor B is to be allocated after another tensor +// A has been deallocated. +// +// If dynamic tensors are used the planning steps can be repeated during model +// execution. Since dynamic tensors don't have sizes until after the +// corresponding operation is executed, this class supports incremental +// planning. +class ArenaPlanner : public MemoryPlanner { + public: + // Ownership of 'context' is not taken and it must remain util the + // ArenaPlanner is destroyed. + ArenaPlanner(TfLiteContext* context, std::unique_ptr graph_info); + ~ArenaPlanner() override; + ArenaPlanner(const ArenaPlanner&) = delete; + ArenaPlanner& operator=(const ArenaPlanner&) = delete; + + TfLiteStatus ResetAllocations() override; + TfLiteStatus PlanAllocations() override; + TfLiteStatus ExecuteAllocations(int first_node, int last_node) override; + + // Returns the base arena location for a given allocation type. + int64_t BasePointer(TfLiteAllocationType type); + + private: + // Make sure all the arenas have reserved enough memory to store all their + // tensors. + TfLiteStatus Commit(); + + // Traverse the allocation queue and reserve space in the appropriate arena + // for all tensors affected by ops in the interval [first_node, last_node]. + TfLiteStatus CalculateAllocations(int first_node, int last_node); + + // Assign absolute memory location to a tensor, based on its relative + // position inside the corresponding arena buffer. + TfLiteStatus ResolveTensorAllocation(int tensor_index); + + // Register an allocation for the given tensor. + TfLiteStatus CalculateTensorAllocation(int tensor_index); + + // Register a deallocation for the given tensor. + TfLiteStatus CalculateTensorDeallocation(int tensor_index); + + // Register an allocation for all internal (temporary) tensors of + // 'node_index'. + TfLiteStatus CalculateAllocationOfInternalTensors(int node_index); + + // Register a deallocation for all internal (temporary) tensors of + // 'node_index'. + TfLiteStatus CalculateDeallocationOfInternalTensors(int node_index); + + TfLiteContext* context_; + std::unique_ptr graph_info_; + + // Stores allocation data for all tensors. + std::vector allocs_; + + // A chronological list of instructions to allocated and deallocate tensors, + // reflecting the way they are used in the graph. + std::vector alloc_queue_; + + // Raw memory buffer that is allocated for all temporary and graph outputs. + // that are declared kTfLiteArenaRw. + SimpleMemoryArena arena_; + + // Raw memory buffer that is allocated for persistent tensors that are + // declared as kTfLiteArenaRwPersistent. + SimpleMemoryArena persistent_arena_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c27c327abc63d7bd1e3912d368a1dacb62c50ca8 --- /dev/null +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -0,0 +1,472 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/arena_planner.h" + +#include + +#include +#include + +namespace tflite { +namespace { + +// A simple op to be used in tests, as syntactic sugar. +class TestOp { + public: + TestOp(std::initializer_list inputs, std::initializer_list outputs, + std::initializer_list temporaries) + : inputs_(inputs), outputs_(outputs), temporaries_(temporaries) {} + + const std::vector& inputs() const { return inputs_; } + const std::vector& outputs() const { return outputs_; } + const std::vector& temporaries() const { return temporaries_; } + + private: + std::vector inputs_; + std::vector outputs_; + std::vector temporaries_; +}; + +// A test graph where inputs are processed by the given nodes to produce +// outputs. +class TestGraph { + public: + TestGraph(std::initializer_list inputs, + std::initializer_list nodes, + std::initializer_list outputs) + : inputs_(inputs), outputs_(outputs) { + int max_tensor_index = 0; + + for (int t : inputs) { + max_tensor_index = std::max(max_tensor_index, t); + } + for (int t : outputs) { + max_tensor_index = std::max(max_tensor_index, t); + } + for (const auto& node : nodes) { + auto int_array = [](const std::vector& x) { + TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); + for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; + return lite; + }; + + nodes_.push_back(TfLiteNode()); + nodes_.back().inputs = int_array(node.inputs()); + for (int t : node.inputs()) { + max_tensor_index = std::max(max_tensor_index, t); + } + nodes_.back().outputs = int_array(node.outputs()); + for (int t : node.outputs()) { + max_tensor_index = std::max(max_tensor_index, t); + } + nodes_.back().temporaries = int_array(node.temporaries()); + for (int t : node.temporaries()) { + max_tensor_index = std::max(max_tensor_index, t); + } + } + + for (int i = 0; i <= max_tensor_index; ++i) { + tensors_.push_back(TfLiteTensor()); + // Set some default values for allocation_type and bytes, which are the + // only fields used by the arena planner. + tensors_.back().allocation_type = kTfLiteArenaRw; + tensors_.back().bytes = (i + 1) * 3; + } + } + + ~TestGraph() { + for (auto node : nodes_) { + TfLiteIntArrayFree(node.inputs); + TfLiteIntArrayFree(node.outputs); + TfLiteIntArrayFree(node.temporaries); + } + } + + const std::vector& nodes() { return nodes_; } + std::vector* tensors() { return &tensors_; } + const std::vector& inputs() { return inputs_; } + const std::vector& outputs() { return outputs_; } + + private: + std::vector nodes_; + std::vector tensors_; + std::vector inputs_; + std::vector outputs_; +}; + +// The GraphInfo for a TestGraph. +class TestGraphInfo : public GraphInfo { + public: + explicit TestGraphInfo(TestGraph* graph) : graph_(graph) {} + + size_t num_tensors() const override { return graph_->tensors()->size(); } + TfLiteTensor* tensor(size_t index) override { + return &graph_->tensors()->at(index); + } + size_t num_nodes() const override { return graph_->nodes().size(); } + const TfLiteNode& node(size_t index) const override { + return graph_->nodes()[index]; + } + const std::vector& inputs() const override { return graph_->inputs(); } + const std::vector& outputs() const override { return graph_->outputs(); } + + private: + TestGraph* graph_; +}; + +void ReportError(TfLiteContext* context, const char* format, ...) { + const size_t kBufferSize = 1024; + char temp_buffer[kBufferSize]; + + va_list args; + va_start(args, format); + vsnprintf(temp_buffer, kBufferSize, format, args); + va_end(args); + + LOG(INFO) << temp_buffer; +} + +class ArenaPlannerTest : public ::testing::Test { + protected: + void SetGraph(TestGraph* graph) { + graph_ = graph; + context_.ReportError = ReportError; + planner_.reset(new ArenaPlanner( + &context_, std::unique_ptr(new TestGraphInfo(graph)))); + CHECK(planner_->ResetAllocations() == kTfLiteOk); + CHECK(planner_->PlanAllocations() == kTfLiteOk); + } + + void Execute(int start, int end) { + CHECK(planner_->ExecuteAllocations(start, end) == kTfLiteOk); + } + + // Returns the actual offset of a given tensor, relative to the start of its + // arena. + int64_t GetOffset(int tensor_index) { + const TfLiteTensor& tensor = (*graph_->tensors())[tensor_index]; + return reinterpret_cast(tensor.data.raw) - + planner_->BasePointer(tensor.allocation_type); + } + + // Returns the first aligned offset after a given tensor. + int64_t GetOffsetAfter(int tensor_index) { + const TfLiteTensor& tensor = (*graph_->tensors())[tensor_index]; + int64_t offset = GetOffset(tensor_index) + tensor.bytes; + // We must make sure the offset is aligned to kDefaultArenaAlignment. + if (offset % 4 != 0) { + offset += 4 - offset % 4; + } + return offset; + }; + + TfLiteContext context_; + TestGraph* graph_; + std::unique_ptr planner_; +}; + +TEST_F(ArenaPlannerTest, EmptyGraph) { + TestGraph graph({}, {}, {}); + SetGraph(&graph); + Execute(0, 10); +} + +TEST_F(ArenaPlannerTest, GraphWithNoOps) { + TestGraph graph({0, 10}, {}, {5, 11}); + SetGraph(&graph); + Execute(0, 10); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(10), GetOffsetAfter(0)); + // The outputs are never allocated because they are not connected to any + // inputs. + EXPECT_EQ(GetOffset(5), 0); + EXPECT_EQ(GetOffset(11), 0); +} + +TEST_F(ArenaPlannerTest, GraphWithOneOp) { + TestGraph graph({1}, {{{1}, {2}, {}}}, {2}); + SetGraph(&graph); + Execute(0, 10); + EXPECT_EQ(GetOffset(1), 0); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); +} + +TEST_F(ArenaPlannerTest, ZeroSizedTensors) { + TestGraph graph({1}, {{{1}, {2}, {}}}, {2}); + (*graph.tensors())[1].bytes = 0; + SetGraph(&graph); + // TODO(ahentz): this is currently broken because the arena finds two + // allocations with the same offset and returns an error. + ASSERT_FALSE(planner_->ExecuteAllocations(0, 10) == kTfLiteOk); + // EXPECT_EQ(GetOffset(1), 0); + // EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); +} + +TEST_F(ArenaPlannerTest, SimpleGraph) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op + }, + {3}); + SetGraph(&graph); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +4 +5 -2 -0 +3 -4 -5 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4}, {5}}, // Second op, with temporary + {{4}, {3}, {}} // Third op + }, + {3}); + SetGraph(&graph); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +5 +4 -2 -0 -5 +3 -4 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithOptionals) { + TestGraph graph({0, -1, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, -1, 5}, {3}, {}} // Third op, with optional + }, + {3}); + SetGraph(&graph); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +4 +5 -2 -0 +3 -4 -5 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithLargeTensor) { + TestGraph graph({0, -1, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4}, {5}}, // Second op, with temporary + {{4, -1}, {3}, {}} // Third op, with optional + }, + {3}); + + // Make #1 very large so its vacancy can be filled with #5 and #4. + (*graph.tensors())[1].bytes = 40; + + SetGraph(&graph); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +5 +4 -2 -0 -5 +3 -4 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentTensor) { + TestGraph graph({0, -1, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4}, {5}}, // Second op, with temporary + {{4, -1}, {3}, {}} // Third op, with optional + }, + {3}); + + // Make #1 persistent so it goes into its own arena. + (*graph.tensors())[1].allocation_type = kTfLiteArenaRwPersistent; + + SetGraph(&graph); + Execute(0, 10); + + // Make sure #0 and #1 were given different memory locations (because they + // will both have offset=0, in different arenas.) + EXPECT_NE((*graph.tensors())[0].data.raw, (*graph.tensors())[1].data.raw); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +5 +4 -2 -0 -5 +3 -4 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), 0); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithDynamicTensor) { + TestGraph graph({0, -1, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4}, {5}}, // Second op, with temporary + {{4, -1}, {3}, {}} // Third op, with optional + }, + {3}); + + // Make #1 dynaic so it does not get allocated. + (*graph.tensors())[1].allocation_type = kTfLiteDynamic; + + SetGraph(&graph); + Execute(0, 10); + + EXPECT_EQ((*graph.tensors())[1].data.raw, nullptr); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +5 +4 -2 -0 -5 +3 -4 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, LargerGraphAndStepwiseAllocation) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2, 3}, {}}, + {{2, 0}, {4, 5}, {6}}, + {{1, -1}, {7}, {}}, + {{7, 3}, {8}, {9}}, + {{4, 5, 8}, {10}, {}}, + }, + {10}); + SetGraph(&graph); + + auto is_unallocated = [&](int tensor_index) { + // TODO(ahentz): We'd to use nullptr to represent unallocated tensors, but + // the current code still points them all to the beginning fo the alloc + // (that is, zero offset). + // return (*graph.tensors())[tensor_index].data.raw == nullptr; + return GetOffset(tensor_index) == 0; + }; + + // The allocation plan is made at the beginning and is independent of + // the execution steps. Here's the allocation order: + // Op0: +0 +1 +2 +3 + // Op1: +6 +4 +5 -6 -0 -2 + // Op2: +7 -1 + // Op3: +9 +8 -9 -3 -7 + // Op4: +10 -4 -5 -8 + + Execute(0, 0); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_TRUE(is_unallocated(6)); + EXPECT_TRUE(is_unallocated(4)); + EXPECT_TRUE(is_unallocated(5)); + EXPECT_TRUE(is_unallocated(7)); + EXPECT_TRUE(is_unallocated(9)); + EXPECT_TRUE(is_unallocated(8)); + EXPECT_TRUE(is_unallocated(10)); + + Execute(1, 1); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(6), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(6)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_TRUE(is_unallocated(7)); + EXPECT_TRUE(is_unallocated(9)); + EXPECT_TRUE(is_unallocated(8)); + EXPECT_TRUE(is_unallocated(10)); + + Execute(2, 2); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(6), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(6)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + // Here's an interesting allocation. Even though #6 requires only 21 bytes, + // its deallocation freed up 24 bytes due to the alignment requirements in + // the arena. That means we can fit #7 in the same space! + EXPECT_EQ(GetOffset(7), GetOffsetAfter(3)); + EXPECT_TRUE(is_unallocated(9)); + EXPECT_TRUE(is_unallocated(8)); + EXPECT_TRUE(is_unallocated(10)); + + Execute(3, 3); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(6), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(6)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(7), GetOffsetAfter(3)); + // The deallocation of #0, #1 and #2 freed up 24 bytes but that's not enough + // for #9, so it goes at the end. + EXPECT_EQ(GetOffset(9), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(8), GetOffsetAfter(9)); + EXPECT_TRUE(is_unallocated(10)); + + Execute(4, 4); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(6), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(6)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(7), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(9), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(8), GetOffsetAfter(9)); + // There's just enough space at the beginning for #10 due to the + // deallocation of #0, #1, #2 and #3 (total 36 bytes, #10 needs + // only 33.) + EXPECT_EQ(GetOffset(10), 0); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // ::tflite::LogToStderr(); + FLAGS_logtostderr = true; + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index e3c9cdd99beb93e356c148298dcbe6498fbe0306..0a097d5a69a8bc15aa03502f7a2131fc36e36091 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -89,11 +89,11 @@ def tflite_jni_linkopts(): return tflite_jni_linkopts_unstripped() + select({ "//tensorflow:android": [ "-s", # Omit symbol table. + "-latomic", # Required for some uses of ISO C++11 in x86. ], "//conditions:default": [], }) - def tflite_jni_binary(name, copts=tflite_copts(), linkopts=tflite_jni_linkopts(), @@ -223,11 +223,12 @@ def gen_selected_ops(name, model): """ out = name + "_registration.cc" tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + tflite_path = "//tensorflow/contrib/lite" native.genrule( name = name, srcs = [model], outs = [out], - cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)") - % (tool, model, out), + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") + % (tool, model, out, tflite_path[2:]), tools = [tool], ) diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index e0f2ef768bfed544ed8acd6c0e3a5823e61a1e8c..4a9023ff33de15dd384531d51e39de4ffeecdb8b 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -1,5 +1,24 @@ #!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../.." + make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 93072bf90bd8a18d9011a74c2eec95d86dbdce8a..3b43a1fd5d383b8b9eee1704b7a1b80b8d4059d4 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -83,6 +83,11 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteRNNParams; +typedef struct { + bool time_major; + TfLiteFusedActivation activation; +} TfLiteSequenceRNNParams; + typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams; typedef enum { @@ -104,10 +109,40 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteAddParams; +typedef struct { + // Number of spatial dimensions. + // For now only NHWC is supported, and the value should always be 2. + int num_spatial_dimensions; + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int block_shape[2]; + int before_paddings[2]; + int after_paddings[2]; +} TfLiteSpaceToBatchNDParams; + +typedef struct { + // Number of spatial dimensions. + // For now only NHWC is supported, and the value should always be 2. + int num_spatial_dimensions; + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int block_shape[2]; + int before_crops[2]; + int after_crops[2]; +} TfLiteBatchToSpaceNDParams; + typedef struct { TfLiteFusedActivation activation; } TfLiteMulParams; +typedef struct { + TfLiteFusedActivation activation; +} TfLiteSubParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteDivParams; + typedef struct { TfLiteFusedActivation activation; } TfLiteL2NormParams; @@ -130,6 +165,14 @@ typedef struct { int new_width; } TfLiteResizeBilinearParams; +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int before_padding[8]; + int after_padding[8]; + int num_dimensions; +} TfLitePadParams; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. @@ -157,6 +200,32 @@ typedef struct { TfLiteCombinerType combiner; } TfLiteEmbeddingLookupSparseParams; +typedef struct { + int axis; +} TfLiteGatherParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int perm[8]; + int num_dimensions; +} TfLiteTransposeParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int axis[8]; + int num_axis_dimensions; + bool keep_dims; +} TfLiteMeanParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int squeeze_dims[8]; + int num_squeeze_dims; +} TfLiteSqueezeParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 41257a53b145cbe7e252c9d4de6ea7ef654431b5..fca71165034a46b39803f4500af8dc5c6f4e8829 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -141,6 +141,7 @@ typedef struct { // A union of points that points to memory for a given tensor. typedef union { int* i32; + int64_t* i64; float* f; char* raw; const char* raw_const; diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/context_test.cc index d0a104f43d9b9d148d80ce26b8ecf732d51ef110..20d6f69a25e9f0bb4323cf5d067b8ebd37bb3c23 100644 --- a/tensorflow/contrib/lite/context_test.cc +++ b/tensorflow/contrib/lite/context_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include +#include "tensorflow/contrib/lite/testing/util.h" namespace tflite { @@ -68,7 +69,7 @@ TEST(IntArray, TestIntArrayEqual) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 571d857be7292998996a4fb8101f0070064aa6be..362e5bee25e95e87fa22bb77904056e732c4e140 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,16 +16,12 @@ set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../.." + DOWNLOADS_DIR=tensorflow/contrib/lite/downloads BZL_FILE_PATH=tensorflow/workspace.bzl -# Ensure it is being run from repo root -if [ ! -f $BZL_FILE_PATH ]; then - echo "Could not find ${BZL_FILE_PATH}": - echo "Likely you are not running this from the root directory of the repository."; - exit 1; -fi - EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc index 6ba5384a94dbf9de03fb2e4e2f63074525eafa2d..03fcd5409ceab1895cea3b9e0e4fcb5a127e6a45 100644 --- a/tensorflow/contrib/lite/error_reporter.cc +++ b/tensorflow/contrib/lite/error_reporter.cc @@ -39,7 +39,9 @@ int ErrorReporter::ReportError(void*, const char* format, ...) { } int StderrReporter::Report(const char* format, va_list args) { - return vfprintf(stderr, format, args); + const int result = vfprintf(stderr, format, args); + fputc('\n', stderr); + return result; } ErrorReporter* DefaultErrorReporter() { diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h index 637d456ce7a754c7da34e551869e49b4efd18e3b..d5715e4f90aead79a617fe4576bfe5100d5e121a 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/contrib/lite/error_reporter.h @@ -25,10 +25,10 @@ namespace tflite { // // Usage: // ErrorReporter foo; -// foo.Report("test %d\n", 5); +// foo.Report("test %d", 5); // or // va_list args; -// foo.Report("test %d\n", args); // where args is va_list +// foo.Report("test %d", args); // where args is va_list // // Sublclass ErrorReporter to provide another reporting destination. // For example, if you have a GUI program, you might redirect to a buffer diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index ea398ad14e8be4c5a0021befc7cc076549b47e23..10f31bb6f17242c9f7f70f0648ec643f99c5ac86 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -123,7 +123,11 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const AVCaptureDevice* device = [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; AVCaptureDeviceInput* deviceInput = [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; - assert(error == nil); + + if (error != nil) { + NSLog(@"Failed to initialize AVCaptureDeviceInput. Note: This app doesn't work with simulator"); + assert(NO); + } if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h index 75b1f1da384b527e8332dfba08fec87c65eff8b1..94046d9728258901091f018fd0d081651145f400 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h @@ -14,8 +14,8 @@ #import -@interface AppDelegate : UIResponder +@interface AppDelegate : UIResponder -@property (strong, nonatomic) UIWindow *window; +@property(strong, nonatomic) UIWindow *window; @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm index 1e808eb976ff3eeda4cf6f81b3c1794c6a037dc8..d1215fa0bffd978b4aaadbd8bc13b07723703c9a 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm @@ -22,8 +22,7 @@ didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { UITabBarController *bar = [[UITabBarController alloc] init]; - [bar setViewControllers: - @[[[RunModelViewController alloc] init]]]; + [bar setViewControllers:@[ [[RunModelViewController alloc] init] ]]; bar.selectedIndex = 0; self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; self.window.rootViewController = bar; @@ -31,14 +30,19 @@ return YES; } -- (void)applicationWillResignActive:(UIApplication *)application {} +- (void)applicationWillResignActive:(UIApplication *)application { +} -- (void)applicationDidEnterBackground:(UIApplication *)application {} +- (void)applicationDidEnterBackground:(UIApplication *)application { +} -- (void)applicationWillEnterForeground:(UIApplication *)application {} +- (void)applicationWillEnterForeground:(UIApplication *)application { +} -- (void)applicationDidBecomeActive:(UIApplication *)application {} +- (void)applicationDidBecomeActive:(UIApplication *)application { +} -- (void)applicationWillTerminate:(UIApplication *)application {} +- (void)applicationWillTerminate:(UIApplication *)application { +} @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h index 4e1a83ccf5a12c609baadab7359c55ec4f464ed8..a4b358b4eb7f6ba109638405091b798d30bd1768 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h @@ -18,7 +18,7 @@ - (IBAction)getUrl:(id)sender; -@property (weak, nonatomic) IBOutlet UITextView *urlContentTextView; -@property (weak, nonatomic) IBOutlet UITextField *urlTextField; +@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView; +@property(weak, nonatomic) IBOutlet UITextField *urlTextField; @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm index 965d83010516c6db72c9e8b1c33079b3eda204de..a885a57b65c5c40ec13cc1c8893e02f4f75ed106 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -14,10 +14,10 @@ #import "RunModelViewController.h" -#include -#include #include #include +#include +#include #include #include #include @@ -30,7 +30,11 @@ #include "ios_image_load.h" #define LOG(x) std::cerr -#define CHECK(x) if (!(x)) { LOG(ERROR) << #x << "failed"; exit(1); } +#define CHECK(x) \ + if (!(x)) { \ + LOG(ERROR) << #x << "failed"; \ + exit(1); \ + } NSString* RunInferenceOnImage(); @@ -49,15 +53,12 @@ NSString* RunInferenceOnImage(); // Returns the top N confidence values over threshold in the provided vector, // sorted by confidence in descending order. -static void GetTopN( - const float* prediction, - const int prediction_size, - const int num_results, const float threshold, - std::vector >* top_results) { +static void GetTopN(const float* prediction, const int prediction_size, const int num_results, + const float threshold, std::vector >* top_results) { // Will contain top N results in ascending order. - std::priority_queue, - std::vector >, - std::greater > > top_result_pq; + std::priority_queue, std::vector >, + std::greater > > + top_result_pq; const long count = prediction_size; for (int i = 0; i < count; ++i) { @@ -88,25 +89,26 @@ static void GetTopN( NSString* FilePathForResourceName(NSString* name, NSString* extension) { NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." - << [extension UTF8String] << "' in bundle."; + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; } return file_path; } NSString* RunInferenceOnImage() { - std::string graph; + NSString* graph = @"mobilenet_v1_1.0_224"; const int num_threads = 1; std::string input_layer_type = "float"; std::vector sizes = {1, 224, 224, 3}; - NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite"); + const NSString* graph_path = FilePathForResourceName(graph, @"tflite"); - std::unique_ptr model(tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); + std::unique_ptr model( + tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); if (!model) { - LOG(FATAL) << "Failed to mmap model " << graph; + LOG(FATAL) << "Failed to mmap model " << [graph UTF8String]; } - LOG(INFO) << "Loaded model " << graph; + LOG(INFO) << "Loaded model " << [graph UTF8String]; model->error_reporter(); LOG(INFO) << "resolved reporter"; @@ -143,7 +145,7 @@ NSString* RunInferenceOnImage() { std::ifstream t; t.open([labels_path UTF8String]); std::string line; - while(t){ + while (t) { std::getline(t, line); label_strings.push_back(line); } @@ -154,7 +156,8 @@ NSString* RunInferenceOnImage() { int image_width; int image_height; int image_channels; - std::vector image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); + std::vector image_data = + LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); const int wanted_width = 224; const int wanted_height = 224; const int wanted_channels = 3; @@ -212,8 +215,7 @@ NSString* RunInferenceOnImage() { std::string predictions = ss.str(); NSString* result = @""; - result = [NSString stringWithFormat: @"%@ - %s", result, - predictions.c_str()]; - + result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; + return result; } diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h index 7287d0d63d5b4c0b9c9a528578b6341cdb9c9954..98934ce41d349b33d4fc010a39a956e52f3d5721 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h @@ -17,9 +17,7 @@ #include -std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); +std::vector LoadImageFromFile(const char* file_name, int* out_width, + int* out_height, int* out_channels); #endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm index 789522d2a9900b136f91f77c4ada682f1a316848..cb0fe1a7650c572d3745066431f2759daa94ffc9 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm @@ -14,17 +14,16 @@ #include "ios_image_load.h" -#include -#include #include #include +#include +#include #import #import -std::vector LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { +std::vector LoadImageFromFile(const char* file_name, int* out_width, int* out_height, + int* out_channels) { FILE* file_handle = fopen(file_name, "rb"); fseek(file_handle, 0, SEEK_END); const size_t bytes_in_file = ftell(file_handle); @@ -32,11 +31,10 @@ std::vector LoadImageFromFile(const char* file_name, std::vector file_data(bytes_in_file); fread(file_data.data(), 1, bytes_in_file, file_handle); fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); + + CFDataRef file_data_ref = + CFDataCreateWithBytesNoCopy(NULL, file_data.data(), bytes_in_file, kCFAllocatorNull); + CGDataProviderRef image_provider = CGDataProviderCreateWithCFData(file_data_ref); const char* suffix = strrchr(file_name, '.'); if (!suffix || suffix == file_name) { @@ -44,12 +42,10 @@ std::vector LoadImageFromFile(const char* file_name, } CGImageRef image; if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); + image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || (strcasecmp(suffix, ".jpeg") == 0)) { + image = + CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); } else { CFRelease(image_provider); CFRelease(file_data_ref); @@ -68,9 +64,10 @@ std::vector LoadImageFromFile(const char* file_name, const int bytes_in_image = (bytes_per_row * height); std::vector result(bytes_in_image); const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + + CGContextRef context = + CGBitmapContextCreate(result.data(), width, height, bits_per_component, bytes_per_row, + color_space, kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); CGColorSpaceRelease(color_space); CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); CGContextRelease(context); diff --git a/tensorflow/contrib/lite/examples/ios/simple/main.mm b/tensorflow/contrib/lite/examples/ios/simple/main.mm index d70550a730720e5d6799a186c1beb3cfa04b0b9d..05cb55ddd7a230593863e64b351f6aac31a1b4d7 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/main.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/main.mm @@ -14,7 +14,7 @@ #import -int main(int argc, char * argv[]) { +int main(int argc, char *argv[]) { @autoreleasepool { NSString *delegateClassName = @"AppDelegate"; return UIApplicationMain(argc, argv, nil, delegateClassName); diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..476d85c0314e331d6d3bad382c331a8458fd01a1 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -0,0 +1,75 @@ +# Description: +# TensorFlow Lite Example Label Image. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") + +exports_files(glob([ + "testdata/*.bmp", +])) + +tf_cc_binary( + name = "label_image", + srcs = [ + "get_top_n.h", + "get_top_n_impl.h", + "label_image.cc", + ], + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + ":bitmap_helpers", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_library( + name = "bitmap_helpers", + srcs = ["bitmap_helpers.cc"], + hdrs = [ + "bitmap_helpers.h", + "bitmap_helpers_impl.h", + "label_image.h", + ], + deps = ["//tensorflow/contrib/lite:string"], +) + +# TODO(ahentz): Test disabled as it has a memory leek from read_bmp +# cc_test( +# name = "label_image_test", +# srcs = [ +# "get_top_n.h", +# "get_top_n_impl.h", +# "label_image_test.cc", +# ], +# data = [ +# "testdata/grace_hopper.bmp", +# ], +# deps = [ +# ":bitmap_helpers", +# "//testing/base/public:gunit", +# ], +# ) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b38cd38c83927c65d251b9356301b6bef7521f2 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include // NOLINT(build/include_order) + +#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" + +#define LOG(x) std::cerr + +namespace tflite { +namespace label_image { + +uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, + int width, int height, int channels, bool top_down) { + for (int i = 0; i < height; i++) { + int src_pos; + int dst_pos; + + for (int j = 0; j < width; j++) { + if (!top_down) { + src_pos = ((height - 1 - i) * row_size) + j * channels; + } else { + src_pos = i * row_size + j * channels; + } + + dst_pos = (i * width + j) * channels; + + switch (channels) { + case 1: + output[dst_pos] = input[src_pos]; + break; + case 3: + // BGR -> RGB + output[dst_pos] = input[src_pos + 2]; + output[dst_pos + 1] = input[src_pos + 1]; + output[dst_pos + 2] = input[src_pos]; + break; + case 4: + // BGRA -> RGBA + output[dst_pos] = input[src_pos + 2]; + output[dst_pos + 1] = input[src_pos + 1]; + output[dst_pos + 2] = input[src_pos]; + output[dst_pos + 3] = input[src_pos + 3]; + break; + default: + LOG(FATAL) << "Unexpected number of channels: " << channels; + break; + } + } + } + + return output; +} + +uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, + int* channels, Settings* s) { + int begin, end; + + std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary); + if (!file) { + LOG(FATAL) << "input file " << input_bmp_name << " not found\n"; + exit(-1); + } + + begin = file.tellg(); + file.seekg(0, std::ios::end); + end = file.tellg(); + size_t len = end - begin; + + if (s->verbose) LOG(INFO) << "len: " << len << "\n"; + + const uint8_t* img_bytes = new uint8_t[len]; + file.seekg(0, std::ios::beg); + file.read((char*)img_bytes, len); + const int32_t header_size = + *(reinterpret_cast(img_bytes + 10)); + *width = *(reinterpret_cast(img_bytes + 18)); + *height = *(reinterpret_cast(img_bytes + 22)); + const int32_t bpp = *(reinterpret_cast(img_bytes + 28)); + *channels = bpp / 8; + + if (s->verbose) + LOG(INFO) << "width, height, channels: " << *width << ", " << *height + << ", " << *channels << "\n"; + + // there may be padding bytes when the width is not a multiple of 4 bytes + // 8 * channels == bits per pixel + const int row_size = (8 * *channels * *width + 31) / 32 * 4; + + // if height is negative, data layout is top down + // otherwise, it's bottom up + bool top_down = (*height < 0); + + // Decode image, allocating tensor once the image size is known + uint8_t* output = new uint8_t[abs(*height) * *width * *channels]; + const uint8_t* bmp_pixels = &img_bytes[header_size]; + return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height), + *channels, top_down); +} + +} // namespace label_image +} // namespace tflite diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..860e27e5ba9cc9fe23d2a7f9f65dd53bbf76f7a3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H + +#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h" +#include "tensorflow/contrib/lite/examples/label_image/label_image.h" + +namespace tflite { +namespace label_image { + +uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, + int* channels, Settings* s); + +template +void downsize(T* out, uint8_t* in, int image_height, int image_width, + int image_channels, int wanted_height, int wanted_width, + int wanted_channels, Settings* s); + +// explicit instantiation +template void downsize(uint8_t*, unsigned char*, int, int, int, int, + int, int, Settings*); +template void downsize(float*, unsigned char*, int, int, int, int, int, + int, Settings*); + +} // namespace label_image +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..64a931082b0cbb4632ec3a814ce654d4f9106bc1 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H + +#include "tensorflow/contrib/lite/examples/label_image/label_image.h" + +namespace tflite { +namespace label_image { + +template +void downsize(T* out, uint8_t* in, int image_height, int image_width, + int image_channels, int wanted_height, int wanted_width, + int wanted_channels, Settings* s) { + for (int y = 0; y < wanted_height; ++y) { + const int in_y = (y * image_height) / wanted_height; + uint8_t* in_row = in + (in_y * image_width * image_channels); + T* out_row = out + (y * wanted_width * wanted_channels); + for (int x = 0; x < wanted_width; ++x) { + const int in_x = (x * image_width) / wanted_width; + uint8_t* in_pixel = in_row + (in_x * image_channels); + T* out_pixel = out_row + (x * wanted_channels); + for (int c = 0; c < wanted_channels; ++c) { + if (s->input_floating) + out_pixel[c] = (in_pixel[c] - s->input_mean) / s->input_std; + else + out_pixel[c] = in_pixel[c]; + } + } + } +} + +} // namespace label_image +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H diff --git a/tensorflow/contrib/lite/examples/label_image/get_top_n.h b/tensorflow/contrib/lite/examples/label_image/get_top_n.h new file mode 100644 index 0000000000000000000000000000000000000000..70a7586fe6a008f0da20a7bac928ca676e5914ab --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/get_top_n.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H + +#include "tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h" + +namespace tflite { +namespace label_image { + +template +void get_top_n(T* prediction, int prediction_size, size_t num_results, + float threshold, std::vector>* top_results, + bool input_floating); + +// explicit instantiation so that we can use them otherwhere +template void get_top_n(uint8_t*, int, size_t, float, + std::vector>*, bool); +template void get_top_n(float*, int, size_t, float, + std::vector>*, bool); + +} // namespace label_image +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H diff --git a/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h b/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..e416fbd39b125ea65d1155b19ab0967a9062e71a --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H + +#include +#include + +namespace tflite { +namespace label_image { + +extern bool input_floating; + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +template +void get_top_n(T* prediction, int prediction_size, size_t num_results, + float threshold, std::vector>* top_results, + bool input_floating) { + // Will contain top N results in ascending order. + std::priority_queue, std::vector>, + std::greater>> + top_result_pq; + + const long count = prediction_size; // NOLINT(runtime/int) + for (int i = 0; i < count; ++i) { + float value; + if (input_floating) + value = prediction[i]; + else + value = prediction[i] / 255.0; + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +} // namespace label_image +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d2e1ce0bc751667393c4b38acc0517980c9f02a --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -0,0 +1,300 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/optional_debug_tools.h" +#include "tensorflow/contrib/lite/string_util.h" + +#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" +#include "tensorflow/contrib/lite/examples/label_image/get_top_n.h" + +#define LOG(x) std::cerr + +namespace tflite { +namespace label_image { + +double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); } + +// Takes a file name, and loads a list of labels from it, one per line, and +// returns a vector of the strings. It pads with empty strings so the length +// of the result is a multiple of 16, because our model expects that. +TfLiteStatus ReadLabelsFile(const string& file_name, + std::vector* result, + size_t* found_label_count) { + std::ifstream file(file_name); + if (!file) { + LOG(FATAL) << "Labels file " << file_name << " not found\n"; + return kTfLiteError; + } + result->clear(); + string line; + while (std::getline(file, line)) { + result->push_back(line); + } + *found_label_count = result->size(); + const int padding = 16; + while (result->size() % padding) { + result->emplace_back(); + } + return kTfLiteOk; +} + +void RunInference(Settings* s) { + if (!s->model_name.c_str()) { + LOG(ERROR) << "no model file name\n"; + exit(-1); + } + + std::unique_ptr model; + std::unique_ptr interpreter; + model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str()); + if (!model) { + LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n"; + exit(-1); + } + LOG(INFO) << "Loaded model " << s->model_name << "\n"; + model->error_reporter(); + LOG(INFO) << "resolved reporter\n"; + + tflite::ops::builtin::BuiltinOpResolver resolver; + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter\n"; + exit(-1); + } + + interpreter->UseNNAPI(s->accel); + + if (s->verbose) { + LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n"; + LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n"; + LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n"; + LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n"; + + int t_size = interpreter->tensors_size(); + for (int i = 0; i < t_size; i++) { + if (interpreter->tensor(i)->name) + LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", " + << interpreter->tensor(i)->bytes << ", " + << interpreter->tensor(i)->type << ", " + << interpreter->tensor(i)->params.scale << ", " + << interpreter->tensor(i)->params.zero_point << "\n"; + } + } + + if (s->number_of_threads != -1) { + interpreter->SetNumThreads(s->number_of_threads); + } + + int image_width = 224; + int image_height = 224; + int image_channels = 3; + uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height, + &image_channels, s); + + int input = interpreter->inputs()[0]; + if (s->verbose) LOG(INFO) << "input: " << input << "\n"; + + const std::vector inputs = interpreter->inputs(); + const std::vector outputs = interpreter->outputs(); + + if (s->verbose) { + LOG(INFO) << "number of inputs: " << inputs.size() << "\n"; + LOG(INFO) << "number of outputs: " << outputs.size() << "\n"; + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } + + if (s->verbose) PrintInterpreterState(interpreter.get()); + + // get input dimension from the input tensor metadata + // assuming one input only + TfLiteIntArray* dims = interpreter->tensor(input)->dims; + int wanted_height = dims->data[1]; + int wanted_width = dims->data[2]; + int wanted_channels = dims->data[3]; + + if (s->input_floating) { + downsize(interpreter->typed_tensor(input), in, image_height, + image_width, image_channels, wanted_height, wanted_width, + wanted_channels, s); + } else { + downsize(interpreter->typed_tensor(input), in, + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); + } + + struct timeval start_time, stop_time; + gettimeofday(&start_time, NULL); + for (int i = 0; i < s->loop_count; i++) { + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to invoke tflite!\n"; + } + } + gettimeofday(&stop_time, NULL); + LOG(INFO) << "invoked \n"; + LOG(INFO) << "average time: " + << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000) + << " ms \n"; + + const int output_size = 1000; + const size_t num_results = 5; + const float threshold = 0.001f; + + std::vector> top_results; + + if (s->input_floating) { + get_top_n(interpreter->typed_output_tensor(0), output_size, + num_results, threshold, &top_results, s->input_floating); + } else { + get_top_n(interpreter->typed_output_tensor(0), + output_size, num_results, threshold, &top_results, + s->input_floating); + } + + std::vector labels; + size_t label_count; + + if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk) + exit(-1); + + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n"; + } +} + +void display_usage() { + LOG(INFO) << "label_image\n" + << "--accelerated, -a: [0|1], use Android NNAPI or note\n" + << "--count, -c: loop interpreter->Invoke() for certain times\n" + << "--input_floating, -f: [0|1] type of input layer is floating " + "point numbers\n" + << "--input_mean, -b: input mean\n" + << "--input_std, -s: input standard deviation\n" + << "--image, -i: image_name.bmp\n" + << "--labels, -l: labels for the model\n" + << "--tflite_mode, -m: model_name.tflite\n" + << "--threads, -t: number of threads\n" + << "--verbose, -v: [0|1] print more information\n" + << "\n"; +} + +int Main(int argc, char** argv) { + Settings s; + + int c; + while (1) { + static struct option long_options[] = { + {"accelerated", required_argument, 0, 'a'}, + {"count", required_argument, 0, 'c'}, + {"input_floating", required_argument, 0, 'f'}, + {"verbose", required_argument, 0, 'v'}, + {"image", required_argument, 0, 'i'}, + {"labels", required_argument, 0, 'l'}, + {"tflite_model", required_argument, 0, 'm'}, + {"threads", required_argument, 0, 't'}, + {"input_mean", required_argument, 0, 'b'}, + {"input_std", required_argument, 0, 's'}, + {0, 0, 0, 0}}; + + /* getopt_long stores the option index here. */ + int option_index = 0; + + c = getopt_long(argc, argv, "a:b:c:f:i:l:m:s:t:v:", long_options, + &option_index); + + /* Detect the end of the options. */ + if (c == -1) break; + + switch (c) { + case 'a': + s.accel = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'b': + s.input_mean = strtod(optarg, NULL); + break; + case 'c': + s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'f': + s.input_floating = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + s.input_layer_type = "float"; + break; + case 'i': + s.input_bmp_name = optarg; + break; + case 'l': + s.labels_file_name = optarg; + break; + case 'm': + s.model_name = optarg; + break; + case 's': + s.input_std = strtod(optarg, NULL); + break; + case 't': + s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'v': + s.verbose = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'h': + case '?': + /* getopt_long already printed an error message. */ + display_usage(); + exit(-1); + default: + exit(-1); + } + } + RunInference(&s); + return 0; +} + +} // namespace label_image +} // namespace tflite + +int main(int argc, char** argv) { + return tflite::label_image::Main(argc, argv); +} diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h new file mode 100644 index 0000000000000000000000000000000000000000..ce98e06fc162a9588707eae701e2fcb8d648a4e4 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/label_image.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H + +#include +#include "tensorflow/contrib/lite/string.h" + +struct Settings { + bool verbose = false; + bool accel = false; + bool input_floating = false; + int loop_count = 1; + float input_mean = 127.5f; + float input_std = 127.5f; + string model_name = "./mobilenet_quant_v1_224.tflite"; + string input_bmp_name = "./grace_hopper.bmp"; + string labels_file_name = "./labels.txt"; + string input_layer_type = "uint8_t"; + int number_of_threads = 4; +}; + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.md b/tensorflow/contrib/lite/examples/label_image/label_image.md new file mode 100644 index 0000000000000000000000000000000000000000..d6019d673f1b15429e69b57e8dc9eeaad2825bc3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/label_image.md @@ -0,0 +1,74 @@ +label_image for TensorFlow Lite inspired by TensorFlow's label_image. + +To build it for android ARMv8: +``` +> bazel build --cxxopt=-std=c++11 \ + --crosstool_top=//external:android/crosstool \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --cpu=arm64-v8a \ + //tensorflow/contrib/lite/examples/label_image:label_image +``` +or +``` +> bazel build --config android_arm64 --cxxopt=-std=c++11 \ + //tensorflow/contrib/lite/examples/label_image:label_image +``` + +To build it for android arm-v7a: +``` +> bazel build --cxxopt=-std=c++11 \ + --crosstool_top=//external:android/crosstool \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --cpu=armeabi-v7a \ + //tensorflow/contrib/lite/examples/label_image:label_image +``` +or +``` +> bazel build --config android_arm --cxxopt=-std=c++11 \ + //tensorflow/contrib/lite/examples/label_image:label_image +``` + +Build it for desktop machines (tested on Ubuntu and OS X) +``` +> bazel build --config opt --cxxopt=-std=c++11 //tensorflow/contrib/lite/examples/label_image:label_image +``` +To run it. Prepare `./mobilenet_quant_v1_224.tflite`, `./grace_hopper.bmp`, and `./labels.txt`. + +Run it: +``` +> ./label_image +Loaded model ./mobilenet_quant_v1_224.tflite +resolved reporter +invoked +average time: 100.986 ms +0.439216: 653 military uniform +0.372549: 458 bow tie +0.0705882: 466 bulletproof vest +0.0235294: 514 cornet +0.0196078: 835 suit +``` +Run `interpreter->Invoker()` 100 times: +``` +> ./label_image -c 100 +Loaded model ./mobilenet_quant_v1_224.tflite +resolved reporter +invoked +average time: 33.4694 ms +... +``` + +Run a floating point (`mobilenet_v1_1.0_224.tflite`) model, +``` +> ./label_image -f 1 -m mobilenet_v1_1.0_224.tflite +Loaded model mobilenet_v1_1.0_224.tflite +resolved reporter +invoked +average time: 263.493 ms +0.88615: 653 military uniform +0.0422316: 440 bearskin +0.0109948: 466 bulletproof vest +0.0105327: 401 academic gown +0.00947104: 723 ping-pong bal +``` + +See the source code for other command line options. diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce35483f76e8f40ced79e1ee30774c62d0eba94e --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" +#include "tensorflow/contrib/lite/examples/label_image/get_top_n.h" +#include "tensorflow/contrib/lite/examples/label_image/label_image.h" + +using ::testing::ElementsAreArray; + +namespace tflite { +namespace label_image { + +TEST(LabelImageTest, GraceHopper) { + std::string lena_file = + "tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp"; + int height, width, channels; + Settings s; + uint8_t *data; + + data = read_bmp(lena_file, &width, &height, &channels, &s); + ASSERT_EQ(height, 606); + ASSERT_EQ(width, 517); + ASSERT_EQ(channels, 3); + + uint8_t *out = new uint8_t[606 * 517 * 3]; + downsize(out, data, 606, 517, 3, 214, 214, 3, &s); + ASSERT_EQ(out[0], 0x15); + ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12); +} + +TEST(LabelImageTest, GetTopN) { + uint8_t in[] = {1, 1, 2, 2, 4, 4, 16, 32, 128, 64}; + + std::vector> top_results; + get_top_n(in, 10, 5, 0.025, &top_results, false); + ASSERT_EQ(top_results.size(), 4); + ASSERT_EQ(top_results[0].second, 8); +} + +} // namespace label_image +} // namespace tflite + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp b/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp new file mode 100644 index 0000000000000000000000000000000000000000..0d94cd3e930a138b7c20308f5ba375576484d48b Binary files /dev/null and b/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp differ diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md index ce8b37fbf9b0db5dee60784e85a3cbf0326fddb6..a359b8d4b481dbc15cc86db14eabda5433722b8b 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/contrib/lite/g3doc/ios.md @@ -45,6 +45,10 @@ into a universal file containing armv7, armv7s, arm64, i386, and x86_64 architectures. The resulting library is in `tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a`. +If you get an error such as `no such file or directory: 'x86_64'` when running +`build_ios_universal_lib.sh`: open Xcode > Preferences > Locations, and ensure +a value is selected in the "Command Line Tools" dropdown. + ## Using in your own application You'll need to update various settings in your app to link against TensorFlow diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 9ade04eb8c696d7e0e39a8104e02b6e5feec95eb..8e5e694a5cbe7f908572114db33c8257db6151f0 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -329,18 +329,18 @@ Inputs { 0: a tensor } Outputs { - 0: a tensor equivalent to max(0, min(input, 1) + 0: a tensor equivalent to max(0, input) } ``` -**RELU1** +**RELU_N1_TO_1** ``` Inputs { 0: a tensor } Outputs { - 0: a tensor equivalent to max(-1, min(input, 6) + 0: a tensor equivalent to max(-1, min(input, 1) } ``` diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h new file mode 100644 index 0000000000000000000000000000000000000000..5481aede605453958adb2c2e661c73130046d9f9 --- /dev/null +++ b/tensorflow/contrib/lite/graph_info.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ + +#include + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Basic information about an inference graph, where execution nodes +// are connected via tensors. +class GraphInfo { + public: + virtual ~GraphInfo() {} + + // Total number of tensors in the graph. + virtual size_t num_tensors() const = 0; + + // Returns a tensor given its index which is expected to be between 0 and + // num_tensors(). + virtual TfLiteTensor* tensor(size_t index) = 0; + + // Total number of nodes in the graph. + virtual size_t num_nodes() const = 0; + + // Returns a node given its index which is expected to be between 0 and + // num_nodes(). + virtual const TfLiteNode& node(size_t index) const = 0; + + // Returns the indices of the input tensors. + virtual const std::vector& inputs() const = 0; + + // Returns the indices of the output tensors. + virtual const std::vector& outputs() const = 0; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 954e236ac8f0c8c59a9d20d62e66b3aa1164ecc1..5f5981e45a20a2c79ea1a2ba08345e831ce194da 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -18,16 +18,16 @@ limitations under the License. #include #include #include +#include "tensorflow/contrib/lite/arena_planner.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/graph_info.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" namespace { -// Memory allocation tuning -constexpr const int kDefaultArenaAlignment = 64; -constexpr const int kDefaultTensorAlignment = 4; // std::vector preallocation tuning. constexpr const int kSlotsToReserve = 128; @@ -35,10 +35,33 @@ constexpr const int kSlotsToReserve = 128; namespace tflite { +// A trivial implementation of GraphInfo around the Interpreter. +class InterpreterInfo : public GraphInfo { + public: + explicit InterpreterInfo(Interpreter* interpreter) + : interpreter_(interpreter) {} + + size_t num_tensors() const override { return interpreter_->tensors_size(); } + TfLiteTensor* tensor(size_t index) override { + return interpreter_->tensor(index); + } + size_t num_nodes() const override { return interpreter_->nodes_size(); } + const TfLiteNode& node(size_t index) const override { + return interpreter_->node_and_registration(index)->first; + } + const std::vector& inputs() const override { + return interpreter_->inputs(); + } + const std::vector& outputs() const override { + return interpreter_->outputs(); + } + + public: + Interpreter* interpreter_; +}; + Interpreter::Interpreter(ErrorReporter* error_reporter) - : arena_(kDefaultArenaAlignment), - persistent_arena_(kDefaultArenaAlignment), - error_reporter_(error_reporter ? error_reporter + : error_reporter_(error_reporter ? error_reporter : DefaultErrorReporter()) { context_.impl_ = static_cast(this); context_.ResizeTensor = ResizeTensor; @@ -50,7 +73,7 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) // Reserve some space for the tensors to avoid excessive resizing. tensors_.reserve(kSlotsToReserve); nodes_and_registration_.reserve(kSlotsToReserve); - next_allocate_node_id_ = 0; + next_node_to_prepare_ = 0; UseNNAPI(false); } @@ -128,181 +151,6 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, return kTfLiteOk; } -TfLiteStatus Interpreter::AllocateTensorsWhoseSizesAreKnown() { - if (!consistent_) { - ReportError(&context_, "AllocateTensors() called on inconsistent model."); - return kTfLiteError; - } - if (next_allocate_node_id_ == nodes_and_registration_.size() && invokable_) { - return kTfLiteOk; - } - allocs_and_refcounts_.resize(context_.tensors_size); - - int new_next_allocate_node_id = next_allocate_node_id_; - invokable_ = false; - - // Allocate graph input nodes. - if (next_allocate_node_id_ == 0) { - for (int i = 0; i < inputs_.size(); ++i) { - int tensor_index = inputs_[i]; - if (tensor_index == kOptionalTensor) { - continue; - } - TfLiteTensor& tensor = context_.tensors[tensor_index]; - if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, - &allocs_and_refcounts_[tensor_index].alloc)); - } - } - // Add 1 to output tensors, so they will not get overwritten. - for (int i = 0; i < outputs_.size(); ++i) { - allocs_and_refcounts_[outputs_[i]].count++; - } - } - - // Count references to node input tensors, and resize node-referenced tensors - // until we encounter a node that has a dynamic output tensor. - for (int k = next_allocate_node_id_; k < nodes_and_registration_.size(); - k++) { - new_next_allocate_node_id++; - TfLiteNode& node = nodes_and_registration_[k].first; - const TfLiteRegistration& registration = nodes_and_registration_[k].second; - if (OpPrepare(registration, &node) == kTfLiteError) { - return kTfLiteError; - } - - TfLiteIntArray* node_inputs = node.inputs; - for (int i = 0; i < node_inputs->size; ++i) { - int tensor_index = node_inputs->data[i]; - if (tensor_index != kOptionalTensor) { - allocs_and_refcounts_[node_inputs->data[i]].count++; - } - } - - // Discontinue if the node has dynamic outputs. - bool has_unallocated_dynamic_tensor = false; - TfLiteIntArray* node_outputs = node.outputs; - for (int i = 0; i < node_outputs->size; ++i) { - TfLiteTensor& tensor = context_.tensors[node_outputs->data[i]]; - if (tensor.allocation_type == kTfLiteDynamic) { - has_unallocated_dynamic_tensor = true; - break; - } - } - if (has_unallocated_dynamic_tensor) { - break; - } - } - - // Allocate graph persistent outputs, e.g. RNN cell states, etc. - for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { - TfLiteNode& node = nodes_and_registration_[k].first; - - // Go through output tensors and allocate the persistent ones first. - TfLiteIntArray* node_outputs = node.outputs; - for (int i = 0; i < node_outputs->size; ++i) { - int tensor_index = node_outputs->data[i]; - TfLiteTensor& tensor = context_.tensors[tensor_index]; - if (tensor.allocation_type == kTfLiteArenaRwPersistent) { - TF_LITE_ENSURE_OK(&context_, - persistent_arena_.Allocate( - &context_, kDefaultTensorAlignment, tensor.bytes, - &allocs_and_refcounts_[tensor_index].alloc)); - } - } - } - - // Go through the graph in execution order. - for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { - TfLiteNode& node = nodes_and_registration_[k].first; - - // First allocate output tensors. - TfLiteIntArray* node_outputs = node.outputs; - for (int i = 0; i < node_outputs->size; ++i) { - int tensor_index = node_outputs->data[i]; - TfLiteTensor& tensor = context_.tensors[tensor_index]; - if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, - &allocs_and_refcounts_[tensor_index].alloc)); - } - } - // Then the temporaries, in two passes. First allocate them all, them - // deallocate them. - TfLiteIntArray* node_temporaries = node.temporaries; - for (int i = 0; i < node_temporaries->size; ++i) { - int tensor_index = node_temporaries->data[i]; - TfLiteTensor& tensor = context_.tensors[tensor_index]; - if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, - &allocs_and_refcounts_[tensor_index].alloc)); - } - } - for (int i = 0; i < node_temporaries->size; ++i) { - int tensor_index = node_temporaries->data[i]; - TfLiteTensor& tensor = context_.tensors[tensor_index]; - allocs_and_refcounts_[tensor_index].count--; - if (tensor.allocation_type == kTfLiteArenaRw && - allocs_and_refcounts_[tensor_index].count == 0) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Deallocate(&context_, - allocs_and_refcounts_[tensor_index].alloc)); - } - } - - // Then process the node's inputs. - TfLiteIntArray* node_inputs = node.inputs; - for (int i = 0; i < node_inputs->size; ++i) { - int tensor_index = node_inputs->data[i]; - if (tensor_index == kOptionalTensor) { - continue; - } - TfLiteTensor& tensor = context_.tensors[tensor_index]; - - // Decrease reference count and deallocate if not needed anymore. - allocs_and_refcounts_[tensor_index].count--; - if (tensor.allocation_type == kTfLiteArenaRw && - allocs_and_refcounts_[tensor_index].count == 0) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Deallocate(&context_, - allocs_and_refcounts_[tensor_index].alloc)); - } - } - } - - // Resize the buffer and commit the arena. - TF_LITE_ENSURE_OK(&context_, arena_.Commit(&context_)); - TF_LITE_ENSURE_OK(&context_, persistent_arena_.Commit(&context_)); - - // Rewire the tensors to use the underlying arena buffer. - for (int i = 0; i < context_.tensors_size; ++i) { - TfLiteTensor& tensor = context_.tensors[i]; - if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_OK( - &context_, - arena_.ResolveAlloc(&context_, allocs_and_refcounts_[i].alloc, - &tensor.data.raw)); - } - if (tensor.allocation_type == kTfLiteArenaRwPersistent) { - TF_LITE_ENSURE_OK( - &context_, - persistent_arena_.ResolveAlloc( - &context_, allocs_and_refcounts_[i].alloc, &tensor.data.raw)); - } - } - - invokable_ = true; - next_allocate_node_id_ = new_next_allocate_node_id; - return kTfLiteOk; -} - namespace { TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); @@ -312,11 +160,19 @@ TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { } // namespace TfLiteStatus Interpreter::AllocateTensors() { - next_allocate_node_id_ = 0; - TF_LITE_ENSURE_OK(&context_, arena_.Clear()); - TF_LITE_ENSURE_OK(&context_, persistent_arena_.Clear()); - allocs_and_refcounts_.clear(); - return AllocateTensorsWhoseSizesAreKnown(); + next_node_to_prepare_ = 0; + if (memory_planner_) { + TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); + } + + if (!consistent_) { + ReportError(&context_, "AllocateTensors() called on inconsistent model."); + return kTfLiteError; + } + + TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + invokable_ = true; + return kTfLiteOk; } TfLiteStatus Interpreter::AddNodeWithParameters( @@ -372,6 +228,57 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); } +// Returns true if at least one tensor in the given list is kTfLiteDynamic. +bool HasDynamicTensor(const TfLiteContext& context, + const TfLiteIntArray* tensors) { + for (int i = 0; i < tensors->size; ++i) { + const TfLiteTensor& tensor = context.tensors[tensors->data[i]]; + if (tensor.allocation_type == kTfLiteDynamic) { + return true; + } + } + return false; +} + +TfLiteStatus Interpreter::PrepareOpsStartingAt(int first_node, + int* last_node_prepared) { + for (int i = first_node; i < nodes_and_registration_.size(); i++) { + TfLiteNode& node = nodes_and_registration_[i].first; + const TfLiteRegistration& registration = nodes_and_registration_[i].second; + if (OpPrepare(registration, &node) == kTfLiteError) { + return kTfLiteError; + } + + *last_node_prepared = i; + + // Discontinue if the node has dynamic outputs. Note that we don't + // stop for dynamic temporary tensors since they won't affect the + // sizes of other tensors in the graph. + if (HasDynamicTensor(context_, node.outputs)) { + break; + } + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::PrepareOpsAndTensors() { + if (!memory_planner_) { + memory_planner_.reset(new ArenaPlanner( + &context_, std::unique_ptr(new InterpreterInfo(this)))); + memory_planner_->PlanAllocations(); + } + + int last_node_prepared = 0; + + TF_LITE_ENSURE_STATUS( + PrepareOpsStartingAt(next_node_to_prepare_, &last_node_prepared)); + TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations( + next_node_to_prepare_, last_node_prepared)); + + next_node_to_prepare_ = last_node_prepared + 1; + return kTfLiteOk; +} + TfLiteStatus Interpreter::Invoke() { if (!consistent_) { ReportError(&context_, "Invoke called on model that is not consistent."); @@ -384,10 +291,8 @@ TfLiteStatus Interpreter::Invoke() { TfLiteStatus status = kTfLiteOk; if (nnapi_delegate_) { - if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { - return kTfLiteError; - } - if (next_allocate_node_id_ == nodes_and_registration_.size()) { + TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + if (next_node_to_prepare_ == nodes_and_registration_.size()) { TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); return kTfLiteOk; } else { @@ -400,14 +305,17 @@ TfLiteStatus Interpreter::Invoke() { } } + // Invocations are always done in node order. + // Note that calling Invoke repeatedly will cause the original memory plan to + // be reused, unless either ResizeInputTensor() or AllocateTensors() has been + // called. + // TODO(b/71913981): we should force recalculation in the presence of dynamic + // tensors, because they may have new value which in turn may affect shapes + // and allocations. for (int i = 0; i < nodes_and_registration_.size(); i++) { - // Ensure we have allocated up to this node. The point of this is to - // allocate as much as possible before running any evaluation, but - // dynamic shapes can prevent this from being possible. - if (i >= next_allocate_node_id_) { - if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { - return kTfLiteError; - } + if (i == next_node_to_prepare_) { + TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + TF_LITE_ENSURE(&context_, next_node_to_prepare_ >= i); } TfLiteNode& node = nodes_and_registration_[i].first; const TfLiteRegistration& registration = nodes_and_registration_[i].second; diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 65c61e44bee48535f884a3afaddc691972f5e04b..38dd402e8a971fd0aab51e98610ad12131441862 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" -#include "tensorflow/contrib/lite/simple_memory_arena.h" +#include "tensorflow/contrib/lite/memory_planner.h" namespace tflite { @@ -49,13 +49,6 @@ constexpr TfLiteType typeToTfLiteType() { return kTfLiteUInt8; } -struct ArenaAllocRefCount { - ArenaAllocRefCount() : alloc(), count(0) {} - - ArenaAlloc alloc; - int count; -}; - // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -276,9 +269,17 @@ class Interpreter { return op_reg.invoke(&context_, node); } - // Allocate tensors whose sizes are known in order of nodes. Discontinue when - // we encounter a node that has a dynamic output tensor. - TfLiteStatus AllocateTensorsWhoseSizesAreKnown(); + // Call OpPrepare() for as many ops as possible, allocating memory for their + // tensors. If an op containing dynamic tensors is found, preparation will be + // postponed until this function is called again. This allows the interpreter + // to wait until Invoke() to resolve the sizes of dynamic tensors. + TfLiteStatus PrepareOpsAndTensors(); + + // Call OpPrepare() for all ops starting at 'first_node'. Stop when a + // dynamic tensors is found or all ops have been prepared. Fill + // 'last_node_prepared' with the id of the op containing dynamic tensors, or + // the last in the graph. + TfLiteStatus PrepareOpsStartingAt(int first_node, int* last_node_prepared); // Tensors needed by the interpreter. Use `AddTensors` to add more blank // tensor entries. Note, `tensors_.data()` needs to be synchronized to the @@ -325,17 +326,6 @@ class Interpreter { std::vector> nodes_and_registration_; - // Raw memory buffer that is allocated for all temporary and graph outputs. - // that are declared kTfLiteArenaRw. - SimpleMemoryArena arena_; - - // Raw memory buffer that is allocated for persistent tensors that are - // declared as kTfLiteArenaRwPersistent. - SimpleMemoryArena persistent_arena_; - - // Stores allocation and reference counts of all tensors. - std::vector allocs_and_refcounts_; - // Whether the model is consistent. That is to say if the inputs and outputs // of every node and the global inputs and outputs are valid indexes into // the tensor array. @@ -356,7 +346,7 @@ class Interpreter { // The error reporter delegate that tflite will forward queries errors to. ErrorReporter* error_reporter_; - // Next node to allocate output tensors. + // Index of the next node to prepare. // During Invoke(), Interpreter will allocate input tensors first, which are // known to be fixed size. Then it will allocate outputs from nodes as many // as possible. When there is a node that produces dynamic sized tensor. @@ -364,10 +354,12 @@ class Interpreter { // node id, and execute the node to generate the output tensor before continue // to allocate successors. This process repeats until all nodes are executed. // NOTE: this relies on the order of nodes that is in topological order. - int next_allocate_node_id_; + int next_node_to_prepare_; // Whether to delegate to NN API std::unique_ptr nnapi_delegate_; + + std::unique_ptr memory_planner_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc index bcff7ed9889e95c13294b6cf0d0f4788991a04df..26cfe6c3e286ed603c2183986c697562e846889c 100644 --- a/tensorflow/contrib/lite/ios_makefile.inc +++ b/tensorflow/contrib/lite/ios_makefile.inc @@ -30,6 +30,9 @@ ifeq ($(TARGET), IOS) ${IPHONEOS_SYSROOT} \ -arch $(IOS_ARCH) \ -O3 + ifeq ($(IOS_ARCH), x86_64) + CXXFLAGS += -msse4.1 + endif CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -fembed-bitcode \ -mno-thumb \ diff --git a/tensorflow/contrib/lite/java/AndroidManifest.xml b/tensorflow/contrib/lite/java/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..f705feacbec38ab5152ce52b701320d8f1cd8d3d --- /dev/null +++ b/tensorflow/contrib/lite/java/AndroidManifest.xml @@ -0,0 +1,7 @@ + + + + + + diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 1de28eb52ddb458df0be0a8f9ef453f7caf68654..9a1a888b93ff981b1d14faa7e847e80be1f167f2 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -7,6 +7,16 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") +load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni") + +# Building tensorflow-lite.aar including 4 variants of .so +# To build an aar for release, run below command: +# bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow/contrib/lite/java:tensorflow-lite +aar_with_jni( + name = "tensorflow-lite", + android_library = ":tensorflowlite", +) android_library( name = "tensorflowlite", @@ -15,6 +25,7 @@ android_library( "src/main/java/org/tensorflow/lite/*.java", ], ), + manifest = "AndroidManifest.xml", visibility = ["//visibility:public"], deps = [ ":tflite_runtime", diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl new file mode 100644 index 0000000000000000000000000000000000000000..4450bc9085555b3416f51bac07ea94a1240e919c --- /dev/null +++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl @@ -0,0 +1,47 @@ +"""Generate zipped aar file including different variants of .so in jni folder.""" + +def aar_with_jni(name, android_library): + # Generate dummy AndroidManifest.xml for dummy apk usage + # (dummy apk is generated by _dummy_app_for_so target below) + native.genrule( + name = name + "_binary_manifest_generator", + outs = [name + "_generated_AndroidManifest.xml"], + cmd = """ +cat > $(OUTS) < + + +EOF +""", + ) + + # Generate dummy apk including .so files and later we extract out + # .so files and throw away the apk. + native.android_binary( + name = name + "_dummy_app_for_so", + manifest = name + "_generated_AndroidManifest.xml", + custom_package = "dummy.package.for.so", + deps = [android_library], + # In some platforms we don't have an Android SDK/NDK and this target + # can't be built. We need to prevent the build system from trying to + # use the target in that case. + tags = ["manual"], + ) + + native.genrule( + name = name, + srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"], + outs = [name + ".aar"], + tags = ["manual"], + cmd = """ +cp $(location {}.aar) $(location :{}.aar) +chmod +w $(location :{}.aar) +origdir=$$PWD +cd $$(mktemp -d) +unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*" +cp -r lib jni +zip -r $$origdir/$(location :{}.aar) jni/*/*.so +""".format(android_library, name, name, name, name), + ) diff --git a/tensorflow/contrib/lite/java/build_aar_for_release.sh b/tensorflow/contrib/lite/java/build_aar_for_release.sh new file mode 100755 index 0000000000000000000000000000000000000000..fbcb1e7db9a3f9b885505e989b7ff7224f2d2b15 --- /dev/null +++ b/tensorflow/contrib/lite/java/build_aar_for_release.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e +set -x + +TMPDIR=`mktemp -d` +trap "rm -rf $TMPDIR" EXIT + +VERSION=1.0 + +BUILDER=bazel +BASEDIR=tensorflow/contrib/lite +CROSSTOOL="//external:android/crosstool" +HOST_CROSSTOOL="@bazel_tools//tools/cpp:toolchain" + +BUILD_OPTS="--cxxopt=--std=c++11 -c opt" +CROSSTOOL_OPTS="--crosstool_top=$CROSSTOOL --host_crosstool_top=$HOST_CROSSTOOL" + +test -d $BASEDIR || (echo "Aborting: not at top-level build directory"; exit 1) + +function build_basic_aar() { + local OUTDIR=$1 + $BUILDER build $BUILD_OPTS $BASEDIR/java:tensorflowlite.aar + unzip -d $OUTDIR $BUILDER-bin/$BASEDIR/java/tensorflowlite.aar + # targetSdkVersion is here to prevent the app from requesting spurious + # permissions, such as permission to make phone calls. It worked for v1.0, + # but minSdkVersion might be the preferred way to handle this. + sed -i -e 's///' $OUTDIR/AndroidManifest.xml +} + +function build_arch() { + local ARCH=$1 + local CONFIG=$2 + local OUTDIR=$3 + mkdir -p $OUTDIR/jni/$ARCH/ + $BUILDER build $BUILD_OPTS $CROSSTOOL_OPTS --cpu=$CONFIG \ + $BASEDIR/java:libtensorflowlite_jni.so + cp $BUILDER-bin/$BASEDIR/java/libtensorflowlite_jni.so $OUTDIR/jni/$ARCH/ +} + +rm -rf $TMPDIR +mkdir -p $TMPDIR/jni + +build_basic_aar $TMPDIR +build_arch arm64-v8a arm64-v8a $TMPDIR +build_arch armeabi-v7a armeabi-v7a $TMPDIR +build_arch x86 x86 $TMPDIR +build_arch x86_64 x86_64 $TMPDIR + +AAR_FILE=`realpath tflite-${VERSION}.aar` +(cd $TMPDIR && zip $AAR_FILE -r *) +echo "New AAR file is $AAR_FILE" + diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index 71b633c5774d93684f651821adad13c378a8243c..2e818f728ef208d30b0eeb27ffd7e3fa0c7c1a2d 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -8,7 +8,12 @@ It's easiest with Android Studio. - You'll need at least SDK version 23. + - Make sure to install the latest version of Bazel. Some distributions + ship with Bazel 0.5.4, which is too old. - Bazel requires Android Build Tools `26.0.1` or higher. + - **Bazel is incompatible with NDK revisions 15 and above,** with revision + 16 being a compile-breaking change. [Download an older version manually + instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites) - You also need to install the Android Support Repository, available through Android Studio under `Android SDK Manager -> SDK Tools -> Android Support Repository`. @@ -16,10 +21,15 @@ 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace) to add SDK and NDK targets. + NOTE: As long as you have the SDK and NDK installed, the `./configure` + script will create these rules for you. Answer "Yes" when the script asks + to automatically configure the `./WORKSPACE`. + - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that you have installed. - By default, Android Studio will install the SDK to `~/Android/Sdk` and - the NDK to `~/Android/Sdk/ndk-bundle`. + the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual + download until Bazel supports NDK 16. See bullet points under (1)). 2. Build the app with Bazel. The demo needs C++11: diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java index e7bad4637041d003c1e507d81c0c30404c587653..e44c5ae6b48eda187079dd3a0a1bc563276d816e 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -73,6 +73,11 @@ public class ImageClassifier { /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */ private byte[][] labelProbArray = null; + /** multi-stage low pass filter * */ + private float[][] filterLabelProbArray = null; + + private static final int FILTER_STAGES = 3; + private static final float FILTER_FACTOR = 0.4f; private PriorityQueue> sortedLabels = new PriorityQueue<>( @@ -93,6 +98,7 @@ public class ImageClassifier { DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); imgData.order(ByteOrder.nativeOrder()); labelProbArray = new byte[1][labelList.size()]; + filterLabelProbArray = new float[FILTER_STAGES][labelList.size()]; Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); } @@ -108,11 +114,38 @@ public class ImageClassifier { tflite.run(imgData, labelProbArray); long endTime = SystemClock.uptimeMillis(); Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime)); + + // Smooth the results across frames. + applyFilter(); + + // Print the results. String textToShow = printTopKLabels(); textToShow = Long.toString(endTime - startTime) + "ms" + textToShow; return textToShow; } + void applyFilter() { + int numLabels = labelList.size(); + + // Low pass filter `labelProbArray` into the first stage of the filter. + for (int j = 0; j < numLabels; ++j) { + filterLabelProbArray[0][j] += + FILTER_FACTOR * (labelProbArray[0][j] - filterLabelProbArray[0][j]); + } + // Low pass filter each stage into the next. + for (int i = 1; i < FILTER_STAGES; ++i) { + for (int j = 0; j < numLabels; ++j) { + filterLabelProbArray[i][j] += + FILTER_FACTOR * (filterLabelProbArray[i - 1][j] - filterLabelProbArray[i][j]); + } + } + + // Copy the last stage filter output back to `labelProbArray`. + for (int j = 0; j < numLabels; ++j) { + labelProbArray[0][j] = (byte)filterLabelProbArray[FILTER_STAGES - 1][j]; + } + } + /** Closes tflite to release resources. */ public void close() { tflite.close(); @@ -177,7 +210,7 @@ public class ImageClassifier { final int size = sortedLabels.size(); for (int i = 0; i < size; ++i) { Map.Entry label = sortedLabels.poll(); - textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow; + textToShow = String.format("\n%s: %4.2f", label.getKey(), label.getValue()) + textToShow; } return textToShow; } diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 1939a078ad8031b99620773c9b91335c4e8f7b22..5ee594dec492ad2fee22e603a6de311b3fed4cac 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -34,7 +34,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { NativeInterpreterWrapper(String modelPath) { errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModel(modelPath, errorHandle); - interpreterHandle = createInterpreter(modelHandle); + interpreterHandle = createInterpreter(modelHandle, errorHandle); } /** @@ -46,7 +46,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelByteBuffer = mappedByteBuffer; errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); - interpreterHandle = createInterpreter(modelHandle); + interpreterHandle = createInterpreter(modelHandle, errorHandle); } /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ @@ -103,11 +103,22 @@ final class NativeInterpreterWrapper implements AutoCloseable { return outputs; } + private static native long[] run( + long interpreterHandle, + long errorHandle, + Object[] sizes, + int[] dtypes, + int[] numsOfBytes, + Object[] values); + /** Resizes dimensions of a specific input. */ void resizeInput(int idx, int[] dims) { resizeInput(interpreterHandle, errorHandle, idx, dims); } + private static native void resizeInput( + long interpreterHandle, long errorHandle, int inputIdx, int[] dims); + void setUseNNAPI(boolean useNNAPI) { useNNAPI(interpreterHandle, useNNAPI); } @@ -245,9 +256,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native String[] getOutputNames(long interpreterHandle); - private static native void resizeInput( - long interpreterHandle, long errorHandle, int inputIdx, int[] dims); - private static native void useNNAPI(long interpreterHandle, boolean state); private static native long createErrorReporter(int size); @@ -256,15 +264,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); - private static native long createInterpreter(long modelHandle); - - private static native long[] run( - long interpreterHandle, - long errorHandle, - Object[] sizes, - int[] dtypes, - int[] numsOfBytes, - Object[] values); + private static native long createInterpreter(long modelHandle, long errorHandle); private static native void delete(long errorHandle, long modelHandle, long interpreterHandle); diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index bc6462eb5466e14769f94c5103984f5201b4b8dc..f3f51b668f068ffcd02862a79b72dbae31d31c02 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -307,12 +307,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( - JNIEnv* env, jclass clazz, jlong model_handle) { + JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle) { tflite::FlatBufferModel* model = convertLongToModel(env, model_handle); if (model == nullptr) return 0; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; auto resolver = ::tflite::CreateOpResolver(); std::unique_ptr interpreter; - tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + TfLiteStatus status = + tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + if (status != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Cannot create interpreter: %s", + error_reporter->CachedErrorMessage()); + } return reinterpret_cast(interpreter.release()); } diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index 430886b7cc04a356d1826843acc1bbebf4189bf7..c52a7e4e439936344be26d5761fb5747db64794a 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -95,11 +95,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: - * Signature: (J)J + * Signature: (JJ)J */ JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( - JNIEnv* env, jclass clazz, jlong model_handle); + JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 9a6894f49c0b7278511717d2671648c6d1763e00..473f73816fd3c0a414a2c2e232dec299579fcbb6 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -25,6 +25,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** Unit tests for {@link org.tensorflow.lite.NativeInterpreterWrapper}. */ +// TODO(b/71818425): Generates model files dynamically. @RunWith(JUnit4.class) public final class NativeInterpreterWrapperTest { @@ -43,6 +44,9 @@ public final class NativeInterpreterWrapperTest { private static final String INVALID_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; + private static final String MODEL_WITH_CUSTOM_OP_PATH = + "tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite"; + @Test public void testConstructor() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); @@ -62,6 +66,18 @@ public final class NativeInterpreterWrapperTest { } } + @Test + public void testConstructorWithUnresolableCustomOp() { + try { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("Cannot create interpreter: Didn't find custom op for name 'Assign'"); + } + } + @Test public void testRunWithFloat() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); diff --git a/tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite b/tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite new file mode 100644 index 0000000000000000000000000000000000000000..e775d56d88854ecdf70233262ff5884d224f4373 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite differ diff --git a/tensorflow/contrib/lite/kernels/Android.bp b/tensorflow/contrib/lite/kernels/Android.bp index f077bcfbed9b310491206d0c1b5b56fdddfbe403..de53078c8af2783cc876636ad350d0adb48fb6a9 100644 --- a/tensorflow/contrib/lite/kernels/Android.bp +++ b/tensorflow/contrib/lite/kernels/Android.bp @@ -32,26 +32,36 @@ cc_library_static { "activations.cc", "add.cc", "basic_rnn.cc", + "batch_to_space_nd.cc", "concatenation.cc", "conv.cc", "depthwise_conv.cc", + "div.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", "fully_connected.cc", + "gather.cc", "hashtable_lookup.cc", "kernel_util.cc", "l2norm.cc", "local_response_norm.cc", "lsh_projection.cc", "lstm.cc", + "mean.cc", "mul.cc", + "pad.cc", "pooling.cc", "register.cc", "reshape.cc", "resize_bilinear.cc", "skip_gram.cc", + "space_to_batch_nd.cc", "space_to_depth.cc", + "squeeze.cc", + "sub.cc", "svdf.cc", + "transpose.cc", + "unidirectional_sequence_rnn.cc", "internal/tensor_utils.cc", "internal/quantization_util.cc", "internal/reference/portable_tensor_utils.cc", diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index bbbfa3e7415bfd7a34dfc7d764da55cac22e7d42..7e9644f36c71ff7e03a04dd01743be811632f077 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:schema_fbs_version", "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/testing:util", "//tensorflow/core:lib", "@com_google_googletest//:gtest", ], @@ -76,26 +77,36 @@ cc_library( "activations.cc", "add.cc", "basic_rnn.cc", + "batch_to_space_nd.cc", "concatenation.cc", "conv.cc", "depthwise_conv.cc", + "div.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", "fully_connected.cc", + "gather.cc", "hashtable_lookup.cc", "kernel_util.cc", "l2norm.cc", "local_response_norm.cc", "lsh_projection.cc", "lstm.cc", + "mean.cc", "mul.cc", + "pad.cc", "pooling.cc", "register.cc", "reshape.cc", "resize_bilinear.cc", "skip_gram.cc", + "space_to_batch_nd.cc", "space_to_depth.cc", + "squeeze.cc", + "sub.cc", "svdf.cc", + "transpose.cc", + "unidirectional_sequence_rnn.cc", ], hdrs = [ "kernel_util.h", @@ -152,6 +163,44 @@ tf_cc_test( ], ) +tf_cc_test( + name = "transpose_test", + size = "small", + srcs = ["transpose_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "space_to_batch_nd_test", + size = "small", + srcs = ["space_to_batch_nd_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "batch_to_space_nd_test", + size = "small", + srcs = ["batch_to_space_nd_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "concatenation_test", size = "small", @@ -200,6 +249,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "unidirectional_sequence_rnn_test", + size = "small", + srcs = ["unidirectional_sequence_rnn_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "l2norm_test", size = "small", @@ -212,6 +273,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "mean_test", + size = "small", + srcs = ["mean_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "mul_test", size = "small", @@ -224,6 +297,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "pad_test", + size = "small", + srcs = ["pad_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "reshape_test", size = "small", @@ -236,6 +321,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gather_test", + size = "small", + srcs = ["gather_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "resize_bilinear_test", size = "small", @@ -395,6 +493,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "squeeze_test", + size = "small", + srcs = ["squeeze_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 7ab60a33e5e2ff61bae5f4c6db85ab9c47a391bc..8ac93bc8c8dcfc66d3822e01b6f9b29a3e49c446 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -349,7 +349,7 @@ TfLiteRegistration* Register_RELU() { return &r; } -TfLiteRegistration* Register_RELU1() { +TfLiteRegistration* Register_RELU_N1_TO_1() { static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, activations::GenericPrepare, activations::Relu1Eval}; diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index f10aee70170d4a94ed54376fa410b22a60f109af..68d49944e51b043b6b82aa1589d22f6ebed37574 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -102,7 +102,7 @@ TEST(FloatActivationsOpTest, Relu) { } TEST(FloatActivationsOpTest, Relu1) { - FloatActivationsOpModel m(BuiltinOperator_RELU1, + FloatActivationsOpModel m(BuiltinOperator_RELU_N1_TO_1, /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); m.SetInput({ 0.0, -0.6, 0.2, -0.4, // @@ -317,7 +317,7 @@ TEST(QuantizedActivationsOpTest, Softmax2D) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc index 8e12a837c4954832ff37a6d1ab377bee9e8d5763..306dfc3e803d3df34061767ba9ced032299bfa26 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -77,9 +77,10 @@ TEST(FloatAddOpModel, NoActivation) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); } -TEST(FloatAddOpModel, ActivationRELU1) { +TEST(FloatAddOpModel, ActivationRELU_N1_TO_1) { FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); + {TensorType_FLOAT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); m.Invoke(); @@ -122,7 +123,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { } } -TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) { +TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; @@ -133,7 +134,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) { for (int i = 0; i < inputs1.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, - ActivationFunctionType_RELU1); + ActivationFunctionType_RELU_N1_TO_1); m.QuantizeAndPopulate(m.input1(), inputs1[i]); m.QuantizeAndPopulate(m.input2(), inputs2[i]); m.Invoke(); @@ -164,8 +165,7 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { } // namespace } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc index dfa75655bcfe7762c6cc4c9a98a71d529028c03a..5ecccb985e91238f1183c8f94a2b5f468758ce55 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -261,7 +261,7 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc new file mode 100644 index 0000000000000000000000000000000000000000..0eed680fdcc2afc4bc72be55a5e7722310fa4538 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace batch_to_space_nd { + +// This file has two implementations of BatchToSpaceND. +enum KernelType { + kReference, + kGenericOptimized, +}; + +struct BatchToSpaceNDContext { + BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteBatchToSpaceNDParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +// Currently, only 4D NHWC input/output op_context are supported. +// The 4D array need to have exactly 2 spatial dimensions. +// TODO(ycling): Support arbitrary dimension in BatchToSpaceND. +const int kInputDimensionNum = 4; +const int kOutputDimensionNum = 4; +const int kSpatialDimensionNum = 2; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // The 2nd tensor (block_shape) and the 3rd tensor (crops) are ignored now. + TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + BatchToSpaceNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + const TfLiteIntArray* input_size = op_context.input->dims; + const int* block_shape = op_context.params->block_shape; + + // Number of batch must be multiple of (block_shape[0] * block_shape[1]). + TF_LITE_ENSURE_EQ(context, + input_size->data[0] % (block_shape[0] * block_shape[1]), 0); + + const int output_batch_size = + input_size->data[0] / (block_shape[0] * block_shape[1]); + const int output_height = input_size->data[1] * block_shape[0]; + const int output_width = input_size->data[2] * block_shape[1]; + const int output_channel_size = input_size->data[3]; + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + output_size->data[0] = output_batch_size; + output_size->data[1] = output_height; + output_size->data[2] = output_width; + output_size->data[3] = output_channel_size; + + return context->ResizeTensor(context, op_context.output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + BatchToSpaceNDContext op_context(context, node); + + int block_shape_dims_array[1] = {kSpatialDimensionNum}; + Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); + +#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ + type::BatchToSpaceND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + op_context.params->block_shape, block_shape_dims, \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + switch (op_context.input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported by BatchToSpace."); + return kTfLiteError; + } +#undef TF_LITE_BATCH_TO_SPACE_ND + return kTfLiteOk; +} + +} // namespace batch_to_space_nd + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, batch_to_space_nd::Prepare, + batch_to_space_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, batch_to_space_nd::Prepare, + batch_to_space_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND() { + return Register_BATCH_TO_SPACE_ND_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3ec4efbebcef9d55d0042d93007018c9f6ee3b58 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BatchToSpaceNDOpModel : public SingleOpModel { + public: + BatchToSpaceNDOpModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list before_crops, + std::initializer_list after_crops) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions( + builder_, builder_.CreateVector(block_shape), + builder_.CreateVector(before_crops), + builder_.CreateVector(after_crops)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(BatchToSpaceNDOpTest, SimpleTest) { + BatchToSpaceNDOpModel m({4, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, + 4, 8, 11, 15, 12, 16})); +} + +TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { + EXPECT_DEATH(BatchToSpaceNDOpModel({3, 2, 2, 1}, {2, 2}, {0, 0}, {0, 0}), + "Cannot allocate tensors"); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc index 94e5b2acdcabeedb4652baa1a008b22bf6bc8433..499856a93cbbfbf9aa1a326912e52ce32bbbdf83 100644 --- a/tensorflow/contrib/lite/kernels/concatenation_test.cc +++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc @@ -156,7 +156,7 @@ TEST(ConcatenationOpTest, FourInputsQuantized) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 18d7a31d594efb6a05fe7292a0194ea17599a65b..1d0a81c3135625c07a3566f5f9a8e5401f0d4db7 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -434,7 +434,7 @@ TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc index 39227b2811e2be719a0be77f89793bcf9366d513..1439c8bce14ad127ed68dc54991aed8b8bb39383 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -180,7 +180,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc new file mode 100644 index 0000000000000000000000000000000000000000..44bd0dc85d50c98ec6b6888e05064a8f2e2731c0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace div { + +// This file has three implementation of Div. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalDivFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDivParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_DIV(type) \ + type::Div(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_DIV(reference_ops); + } else { + TF_LITE_DIV(optimized_ops); + } +#undef TF_LITE_DIV +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalDivFloat(context, node, params, input1, input2, output); + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace div + +TfLiteRegistration* Register_DIV_REF() { + static TfLiteRegistration r = {nullptr, nullptr, div::Prepare, + div::Eval}; + return &r; +} + +TfLiteRegistration* Register_DIV_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, div::Prepare, + div::Eval}; + return &r; +} + +TfLiteRegistration* Register_DIV_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, div::Prepare, + div::Eval}; + return &r; +} + +TfLiteRegistration* Register_DIV() { +#ifdef USE_NEON + return Register_DIV_NEON_OPT(); +#else + return Register_DIV_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc index 69d9c5cc7dec13a65f1c5050f2f1c56812ad5aa1..dcdc5fffad9ceac1a9d23a4e91637a9ff92a8dda 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc @@ -158,9 +158,7 @@ TEST(EmbeddingLookupOpTest, Indices3DTest) { } // namespace tflite int main(int argc, char** argv) { -#ifdef OS_LINUX - tflite::LogToStderr(); -#endif + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 8c030b06772ac0c6af34a45897f03ebc4637d4de..9b501878f196216a61568bfa36e6615f4dd07478 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -88,7 +88,7 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc index 112e3f1ba01a428023eea5ee8410fb76c1d67de6..a0f766c4f4580d7679275c0b63aa200410fcb5ad 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -370,8 +370,7 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8df797daf7338e33b16508c21fc61cd9836db1e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -0,0 +1,130 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace gather { +constexpr int kInputTensor = 0; +constexpr int kInputPositions = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const auto* params = + reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* positions = GetInput(context, node, kInputPositions); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // Only INT32 positions are supported. + TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); + // Check that input and output types match. + TF_LITE_ENSURE_EQ(context, input->type, output->type); + // TODO(mgubin): only 1D positions are currently supported. + TF_LITE_ENSURE_EQ(context, NumDimensions(positions), 1); + // TODO(mgubin): Only default axis == 0 is supported. + // Check conditions for different types. + switch (input->type) { + case kTfLiteFloat32: + case kTfLiteUInt8: + case kTfLiteInt32: { + // Fully supported by reference_ops::Gather. + } break; + + case kTfLiteString: { + // Only 1D input is supported. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + } break; + default: + context->ReportError(context, + "Only float32 and string types are supported"); + return kTfLiteError; + } + const int num_dimensions = + NumDimensions(input) + NumDimensions(positions) - 1; + TF_LITE_ENSURE(context, params->axis < num_dimensions); + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + int output_index = 0; + for (int i = 0; i < params->axis; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + for (int i = 0; i < positions->dims->size; ++i) { + output_shape->data[output_index++] = positions->dims->data[i]; + } + for (int i = params->axis + 1; i < input->dims->size; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* positions = GetInput(context, node, kInputPositions); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const int input_rank = NumDimensions(input); +#define TF_LITE_GATHER(data_type, index_type) \ + optimized_ops::Gather( \ + GetTensorData(input), GetTensorDims(input), input_rank, \ + GetTensorData(positions), GetTensorDims(positions), \ + GetTensorData(output), GetTensorDims(output)); + switch (input->type) { + case kTfLiteFloat32: + TF_LITE_GATHER(float, int32_t); + break; + case kTfLiteUInt8: + TF_LITE_GATHER(uint8_t, int32_t); + break; + case kTfLiteInt32: + TF_LITE_GATHER(int32_t, int32_t); + break; + case kTfLiteString: { + DynamicBuffer buffer; + const int32* indexes = positions->data.i32; + const int num_strings = GetStringCount(input); + for (int i = 0; i < positions->dims->data[0]; ++i) { + const int pos = indexes[i]; + TF_LITE_ENSURE(context, pos < num_strings); + const auto string_ref = GetString(input, pos); + buffer.AddString(string_ref.str, string_ref.len); + } + buffer.WriteToTensor(output); + } break; + default: + return kTfLiteError; + } +#undef TF_LITE_GATHER + return kTfLiteOk; +} +} // namespace gather + +TfLiteRegistration* Register_GATHER() { + static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare, + gather::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6343d3b4ef20ae3e030396ec1b6adbcf83a3e45f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class GatherOpModel : public SingleOpModel { + public: + GatherOpModel(std::initializer_list input_shape, TensorType input_type, + std::initializer_list positions_shape) { + input_ = AddInput(input_type); + positions_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_GATHER, BuiltinOptions_GatherOptions, + CreateGatherOptions(builder_, 0).Union()); + BuildInterpreter({input_shape, positions_shape}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUint8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(std::initializer_list data) { + PopulateStringTensor(input_, data); + } + + void SetPositions(std::initializer_list data) { + PopulateTensor(positions_, data); + } + + std::vector GetOutputFloat() { return ExtractVector(output_); } + std::vector GetOutputUint8() { + return ExtractVector(output_); + } + std::vector GetOutputString() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int positions_; + int output_; +}; + +TEST(GatherOpTest, Shuffle) { + GatherOpModel m({2, 2}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2}))); +} + +TEST(FloatGatherOpTest, Duplicate) { + GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({0, 0}); + m.Invoke(); + EXPECT_THAT( + m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({-2, 0.2, 0.7, 0.8, -2, 0.2, 0.7, 0.8}))); +} + +TEST(FloatGatherOpTest, Slice) { + GatherOpModel m({4, 1}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({0.2, 0.8}))); +} + +TEST(Uint8tGatherOpTest, Shuffle) { + GatherOpModel m({2, 2}, TensorType_UINT8, {2}); + m.SetInputUint8({133, 134, 14, 15}); + m.SetPositions({1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutputUint8(), ElementsAreArray({14, 15, 133, 134})); +} + +TEST(GatherOpTest, SimpleString) { + GatherOpModel m({3}, TensorType_STRING, {2}); + m.SetInput({"A", "B", "C"}); + m.SetPositions({0, 2}); + m.Invoke(); + ASSERT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutputString(), ElementsAreArray({"A", "C"})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc index 916a23225e2ad3c5645a7809169677a7a8880535..cb6038f9009a3865661e7b4f075c3033166d0f91 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc @@ -170,7 +170,7 @@ TEST(HashtableLookupOpTest, TestString) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 288534099b9e090ce0c223a401b4152ca6ffb61f..a3ecb2ebf6a889729954d1e447997c510e8ff6d4 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -124,6 +124,13 @@ config_setting( }, ) +config_setting( + name = "freebsd", + values = { + "cpu": "freebsd", + }, +) + cc_library( name = "optimized_base", srcs = [], @@ -147,6 +154,7 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":freebsd": tflite_deps_intel, "//conditions:default": [], }), ) @@ -224,6 +232,7 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":freebsd": tflite_deps_intel, "//conditions:default": [], }), ) diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h index 974611f52ac74cec275f978c5af5bd561688db78..da34c8aef94b1c69e661bd33fcb518e73034c4bd 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -311,6 +311,9 @@ struct FloatDepthwiseConvKernel { } }; +// Note this implementation is very slow for input_depths < 8 +// (e.g. comparable to reference implementation) see, specializations for +// input_depth=3 below. template <> struct FloatDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -417,6 +420,74 @@ struct FloatDepthwiseConvKernel { } }; +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x2_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1_f32(filter_ptr + 2 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x2_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate for each input channel there 2 outputs + acc[0] = vmla_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmla_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmla_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 6; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // NOTE: we only want 3 values, so we read it as two ops where + // the second op just duplicates the lane + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x4_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate all outputs. + acc[0] = vmlaq_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmlaq_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmlaq_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 12; + input_ptr += input_ptr_increment; + } + } +}; + template <> struct FloatDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -857,6 +928,8 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 4) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) // Finally, the kernels allowing a variable input depth, diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 1cd6442c83db77affa17c3a494475c61a9717105..ded5ae8ff50cfc5337a5ea5f6e4880b701246aa6 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -1868,6 +1868,61 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, output_data, output_dims); } +// TODO(aselle): This is not actually optimized yet. +inline void Div(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] / + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// TODO(aselle): This is not actually optimized yet. +inline void Sub(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] - + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, @@ -3381,10 +3436,11 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, for (int out_h = 0; out_h < output_height; ++out_h) { for (int out_w = 0; out_w < output_width; ++out_w) { T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); - if (out_h * block_shape_height < padding_top || - out_h * block_shape_height >= padding_top + input_height || - out_w * block_shape_width < padding_left || - out_w * block_shape_width >= padding_left + input_width) { + if (out_h * block_shape_height + shift_h < padding_top || + out_h * block_shape_height + shift_h >= + padding_top + input_height || + out_w * block_shape_width + shift_w < padding_left || + out_w * block_shape_width + shift_w >= padding_left + input_width) { memset(out, 0, depth * sizeof(T)); } else { const T* in = @@ -3704,6 +3760,43 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, auto max_value = input2_data[0]; output_map.array() = input1_map.array().max(max_value); } + +template +void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, + T2* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ArgMax"); + + // The current ArgMax implemention can only determine the index of the maximum + // value in the last dimension. So the axis argument is ignored. + TFLITE_DCHECK_EQ(axis[0], 3); + + // For ArgMax, the number of output dimensions = (number of input dimensions - + // 1). For the sake of simplicity, the output dimensions are equal to the + // input dimensions here. We enforce the constraint that the last dimension + // must always be 1. + TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = ArraySize(input_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + auto max_value = input_data[Offset(input_dims, 0, x, y, b)]; + int max_index = 0; + for (int d = 1; d < depth; ++d) { + const auto& curr_value = input_data[Offset(input_dims, d, x, y, b)]; + if (curr_value > max_value) { + max_value = curr_value; + max_index = d; + } + } + output_data[Offset(output_dims, 0, x, y, b)] = max_index; + } + } + } +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index c2ab78000b81485f037c507933cd024e70f39850..7f90d731b8454a020ab273e6b5591ed90aab14c7 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -22,7 +22,7 @@ limitations under the License. namespace tflite { namespace tensor_utils { -// Limit a float input f betweeen +abs_limit and -abs_limit. +// Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); // Multiply a matrix by a batch vector, and store results in a batch-size diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index f5c43f1fd98f130507f6b3f216c4a83593d26a13..7f1f3143e8e2fa1e4a7c2a1902920e9e86ad7f68 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -1149,6 +1149,60 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, output_data, output_dims); } +inline void Div(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] / + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +inline void Sub(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] - + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, @@ -2183,10 +2237,11 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, for (int out_h = 0; out_h < output_height; ++out_h) { for (int out_w = 0; out_w < output_width; ++out_w) { T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); - if (out_h * block_shape_height < padding_top || - out_h * block_shape_height >= padding_top + input_height || - out_w * block_shape_width < padding_left || - out_w * block_shape_width >= padding_left + input_width) { + if (out_h * block_shape_height + shift_h < padding_top || + out_h * block_shape_height + shift_h >= + padding_top + input_height || + out_w * block_shape_width + shift_w < padding_left || + out_w * block_shape_width + shift_w >= padding_left + input_width) { memset(out, 0, depth * sizeof(T)); } else { const T* in = @@ -2335,6 +2390,64 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } +template +inline void Mean(T* input_data, const int* input_dims, const int input_num_dims, + T* output_data, const int* output_dims, + const int output_num_dims, const int* axis, + const int num_axis_dimensions, bool keep_dims, int* temp_index, + int* resolved_axis) { + // resets output data. + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + num_outputs *= static_cast(output_dims[idx]); + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = 0; + } + // resets temp index. + for (int idx = 0; idx < input_num_dims; ++idx) { + temp_index[idx] = 0; + } + // resolves axis. + int num_resolved_axis = 0; + for (int idx = 0; idx < num_axis_dimensions; ++idx) { + int current = axis[idx]; + TFLITE_DCHECK(current < input_num_dims && current + input_num_dims >= 0); + if (current < 0) { + current += input_num_dims; + } + bool is_dup = false; + for (int j = 0; j < num_resolved_axis; ++j) { + if (resolved_axis[j] == current) { + is_dup = true; + break; + } + } + if (!is_dup) { + resolved_axis[num_resolved_axis++] = current; + } + } + // iterates through input_data. + for (bool has_next = true; has_next; + has_next = NextIndex(input_num_dims, input_dims, temp_index)) { + size_t input_offset = + ReducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr); + size_t output_offset = + ReducedOutputOffset(input_num_dims, input_dims, temp_index, + num_resolved_axis, resolved_axis); + output_data[output_offset] += input_data[input_offset]; + } + // takes average by num of elements added to get mean. + size_t num_elements_in_axis = 1; + for (int idx = 0; idx < num_resolved_axis; ++idx) { + num_elements_in_axis *= static_cast(input_dims[resolved_axis[idx]]); + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = static_cast(static_cast(output_data[idx]) / + num_elements_in_axis); + } +} + template inline void Mean(const T* input_data, const Dims<4>& input_dims, const std::vector& reduction_indices, T* output_data, @@ -2449,6 +2562,69 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, } } +template +void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, + T2* output_data, const Dims<4>& output_dims) { + // The current ArgMax implemention can only determine the index of the maximum + // value in the last dimension. So the axis argument is ignored. + TFLITE_DCHECK_EQ(axis[0], 3); + + // For ArgMax, the number of output dimensions = (number of input dimensions - + // 1). For the sake of simplicity, the output dimensions are equal to the + // input dimensions here. We enforce the constraint that the last dimension + // must always be 1. + TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = ArraySize(input_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + auto max_value = input_data[Offset(input_dims, 0, x, y, b)]; + int max_index = 0; + for (int d = 1; d < depth; ++d) { + const auto& curr_value = input_data[Offset(input_dims, d, x, y, b)]; + if (curr_value > max_value) { + max_value = curr_value; + max_index = d; + } + } + output_data[Offset(output_dims, 0, x, y, b)] = max_index; + } + } + } +} + +template +void Transpose(const T* input, const Dims<4>& input_dims, T* output, + const Dims<4>& output_dims, int* permuted_axes) { + int out_sizes[4]; + // Compute the inverse permutation array so we can do an output centered + // transpose. Also, check to make sure output_dims is matching input_dims. + for (int k = 0; k < 4; k++) { + out_sizes[k] = + MatchingArraySize(input_dims, permuted_axes[k], output_dims, k); + } + + // Naive transpose loop (iterate on output index and compute input index). + int o[4]; // loop index (on output). + int i[4]; + for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) { + i[permuted_axes[3]] = o[3]; + for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) { + i[permuted_axes[2]] = o[2]; + for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) { + i[permuted_axes[1]] = o[1]; + for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) { + i[permuted_axes[0]] = o[0]; + output[Offset(output_dims, o)] = input[Offset(input_dims, i)]; + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index ee4111e0416560d94d513c528971bdf3bf819662..1961e1a2d5ecd4fd20c6f442b79dc88ed28062fe 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -41,8 +41,7 @@ inline int32_t* GetTensorData(TfLiteTensor* tensor) { template <> inline int64_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? reinterpret_cast(tensor->data.raw) - : nullptr; + return tensor != nullptr ? tensor->data.i64 : nullptr; } inline int RemapDim(int max_dimensions, int d) { diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 0e69ef5982f01e364d865684652d1dfecab6fee3..e7e2994397650004c7ba442fa1803290e6b12302 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -20,7 +20,7 @@ limitations under the License. namespace tflite { namespace tensor_utils { -// Limit a float input f betweeen +abs_limit and -abs_limit. +// Limit a float input f between +abs_limit and -abs_limit. float Clip(float f, float abs_limit); // Multiply a matrix by a batch vector, and store results in a batch-size diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 07f1cb40045fff3ae47ed4efa6ec43b0cb88a0a7..5989ac8fcdec101c14dd7b04d89fe8c7bfce0a10 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -27,6 +27,58 @@ struct Dims { int strides[N]; }; +// Gets next index to iterate through a multidimensional array. +inline bool NextIndex(const int num_dims, const int* dims, int* current) { + TFLITE_DCHECK_GT(num_dims, 0); + TFLITE_DCHECK(dims != nullptr); + TFLITE_DCHECK(current != nullptr); + int carry = 1; + for (int idx = num_dims - 1; idx >= 0; --idx) { + int current_val = current[idx] + carry; + TFLITE_DCHECK_GE(dims[idx], current_val); + if (dims[idx] == current_val) { + current[idx] = 0; + } else { + current[idx] = current_val; + carry = 0; + break; + } + } + return (carry == 0); +} + +// Gets offset of index if reducing on axis. When reducing, the flattened offset +// will not change, if the input index changes on the given axis. For example, +// if you have a 3D tensor and you are reducing to 2D by eliminating axis 0, +// then index (0, 1, 2) and index (1, 1, 2) will map to the same flattened +// offset. +// TODO(kanlig): uses Dims to represent dimensions. +inline size_t ReducedOutputOffset(const int num_dims, const int* dims, + const int* index, const int num_axis, + const int* axis) { + TFLITE_DCHECK_GT(num_dims, 0); + TFLITE_DCHECK(dims != nullptr); + TFLITE_DCHECK(index != nullptr); + size_t offset = 0; + for (int idx = 0; idx < num_dims; ++idx) { + // if we need to skip this axis + bool is_axis = false; + if (axis != nullptr) { + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { + if (idx == axis[axis_idx]) { + is_axis = true; + break; + } + } + } + if (!is_axis) { + offset = offset * static_cast(dims[idx]) + + static_cast(index[idx]); + } + } + return offset; +} + inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); @@ -36,6 +88,10 @@ inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { i3 * dims.strides[3]; } +inline int Offset(const Dims<4>& dims, int* index) { + return Offset(dims, index[0], index[1], index[2], index[3]); +} + // Get array size, DCHECKing that the dim index is in range. template int ArraySize(const Dims& array, int index) { diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index f43aa372b6398a38e57dd38f3d7c7db2bd3aefc1..ee8bfe56d95e9f383ef49b40b8f58b63d61da3e1 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -43,8 +43,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - // TODO(ahentz): Our current implementations rely on the inputs being 4D. - TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE(context, NumDimensions(input) <= 4); // TODO(ahentz): Our current implementations only support float32. TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); @@ -54,12 +53,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // activations. TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); - TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); - output_size->data[0] = input->dims->data[0]; - output_size->data[1] = input->dims->data[1]; - output_size->data[2] = input->dims->data[2]; - output_size->data[3] = input->dims->data[3]; - + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims); return context->ResizeTensor(context, output, output_size); } diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc index b1db89b8bd3474ac868d7215e4a0de12088c48ef..30e103f3303484c339ef98e6a68e0438291c102f 100644 --- a/tensorflow/contrib/lite/kernels/l2norm_test.cc +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -57,7 +57,7 @@ TEST(L2NormOpTest, SimpleTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc index 63a8b0a3d0186def7da2c9f31481721f1a55281c..d75ce258a04c820d8f82735988c01d0154ef36f2 100644 --- a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc +++ b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc @@ -95,7 +95,7 @@ TEST(LocalResponseNormOpTest, SmallRadius) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc index 1011927848d586c8541fb694914b5eee123cb8dc..414d728dfc153058ec878d3c766f58e86815cd3f 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc +++ b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc @@ -117,7 +117,7 @@ TEST(LSHProjectionOpTest2, Sparse3DInputs) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index be4c7ddbf88fc902368cda13aff72f5aecb9dac4..c068286b0d84bcb51ebb0e239350a42863de6523 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -1081,8 +1081,7 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/mean.cc new file mode 100644 index 0000000000000000000000000000000000000000..540e5a364dd60a42c316199d0ebe878ae07e6756 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mean.cc @@ -0,0 +1,200 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace mean { + +// This file has reference implementation of Mean. +enum KernelType { + kReference, +}; + +struct MeanContext { + MeanContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteMeanParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // Creates two temp tensors to store index and axis for internal + // implementation only. + auto* scratch_tensor_index = new int; + context->AddTensors(context, 2, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + MeanContext op_context(context, node); + int input_num_dims = NumDimensions(op_context.input); + int axis_num_dims = op_context.params->num_axis_dimensions; + + // Creates a temp index to iterate through input data. + int* scratch_tensor_index = reinterpret_cast(node->user_data); + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries->data[0] = *scratch_tensor_index; + TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + scratch_tensor->type = kTfLiteInt32; + scratch_tensor->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* index_size = TfLiteIntArrayCreate(1); + index_size->data[0] = input_num_dims; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, scratch_tensor, index_size)); + + // Creates a temp tensor to store resolved axis given input data. + node->temporaries->data[1] = *scratch_tensor_index + 1; + TfLiteTensor* axis_tensor = &context->tensors[node->temporaries->data[1]]; + axis_tensor->type = kTfLiteInt32; + axis_tensor->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1); + axis_size->data[0] = op_context.params->num_axis_dimensions; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, axis_tensor, axis_size)); + + // Determines size of output tensor. + const TfLiteIntArray* input_dims = op_context.input->dims; + const int* axis = op_context.params->axis; + if (op_context.params->keep_dims) { + TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims); + for (int idx = 0; idx < input_num_dims; ++idx) { + bool is_axis = false; + for (int axis_idx = 0; axis_idx < axis_num_dims; ++axis_idx) { + if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) { + is_axis = true; + break; + } + } + if (is_axis) { + output_dims->data[idx] = 1; + } else { + output_dims->data[idx] = input_dims->data[idx]; + } + } + return context->ResizeTensor(context, op_context.output, output_dims); + } else { + // Calculates size of reducing axis. + int num_reduce_axis = axis_num_dims; + for (int i = 0; i < axis_num_dims; ++i) { + int current = axis[i]; + if (current < 0) { + current += input_num_dims; + } + TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims); + for (int j = 0; j < i; ++j) { + int previous = axis[j]; + if (previous < 0) { + previous += input_num_dims; + } + if (current == previous) { + --num_reduce_axis; + break; + } + } + } + // Determines output dimensions. + TfLiteIntArray* output_dims = + TfLiteIntArrayCreate(input_num_dims - num_reduce_axis); + int num_skip_axis = 0; + for (int idx = 0; idx < input_num_dims; ++idx) { + bool is_axis = false; + for (int axis_idx = 0; axis_idx < axis_num_dims; ++axis_idx) { + if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) { + ++num_skip_axis; + is_axis = true; + break; + } + } + if (!is_axis) { + output_dims->data[idx - num_skip_axis] = input_dims->data[idx]; + } + } + return context->ResizeTensor(context, op_context.output, output_dims); + } +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + MeanContext op_context(context, node); + TfLiteTensor* temp_index = &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + +#define TF_LITE_MEAN(kernel_type, data_type) \ + kernel_type::Mean<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + op_context.params->axis, op_context.params->num_axis_dimensions, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_MEAN(reference_ops, float); + break; + case kTfLiteInt32: + TF_LITE_MEAN(reference_ops, int); + break; + case kTfLiteUInt8: + TF_LITE_MEAN(reference_ops, uint8_t); + break; + case kTfLiteInt64: + TF_LITE_MEAN(reference_ops, int64_t); + break; + default: + return kTfLiteError; + } + } +#undef TF_LITE_MEAN + return kTfLiteOk; +} + +} // namespace mean + +TfLiteRegistration* Register_MEAN_REF() { + static TfLiteRegistration r = {mean::Init, mean::Free, mean::Prepare, + mean::Eval}; + return &r; +} + +// TODO(kanlig): add optimized implementation of Mean. +TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4305c0632f5a52b858a056109187ad4a0cc2e46e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mean_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseMeanOpModel : public SingleOpModel { + public: + BaseMeanOpModel(const TensorData& input, const TensorData& output, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp( + BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, + CreateMeanOptions(builder_, builder_.CreateVector(axis), keep_dims) + .Union()); + BuildInterpreter({GetShape(input_)}); + } + + int input() { return input_; } + + protected: + int input_; + int output_; +}; + +class FloatMeanOpModel : public BaseMeanOpModel { + public: + using BaseMeanOpModel::BaseMeanOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } +}; + +TEST(FloatMeanOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + FloatMeanOpModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + +TEST(FloatMeanOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + FloatMeanOpModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index 4b858e1f396252e7f7bdc231bc1e00f47277f08a..8838b300c0af167bf2ffcf944fc7c31d6173f462 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -78,9 +78,10 @@ TEST(FloatMulOpTest, NoActivation) { ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); } -TEST(FloatMulOpTest, ActivationRELU1) { +TEST(FloatMulOpTest, ActivationRELU_N1_TO_1) { FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); + {TensorType_FLOAT32, {}}, + ActivationFunctionType_RELU_N1_TO_1); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 5}); m.Invoke(); @@ -120,8 +121,7 @@ TEST(QuantizedMulOpTest, NoActivation) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h index 7535afaf8ea52d855e2e4773e56ce2118a16447c..63670efcb1e6349317aa5c75756707fb7a7fa2aa 100644 --- a/tensorflow/contrib/lite/kernels/op_macros.h +++ b/tensorflow/contrib/lite/kernels/op_macros.h @@ -15,6 +15,8 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#include + #define TF_LITE_FATAL(msg) \ do { \ fprintf(stderr, "%s\n", (msg)); \ diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index 8e9cc07656c8bea83f7cb78ca0b6cc5de7ad1b73..17166715ca30ff3d8ba3d384110e403f8910e39d 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -334,8 +334,7 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc new file mode 100644 index 0000000000000000000000000000000000000000..1a0d9d1505d41fb7948863f9da9e2a4f1b61e4f9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -0,0 +1,159 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pad { + +// This file has two implementations of Pad. +enum KernelType { + kReference, + kGenericOptimized, +}; + +// TODO(nupurgarg): Padding represented as a tensor is ignored. Only use the +// `left_padding` and `right_padding` specified in `params`. +struct PadContext { + PadContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLitePadParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Determines size of output tensor. + PadContext op_context(context, node); + int dims = NumDimensions(op_context.input); + TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, dims, 4); + + const TfLiteIntArray* input_size = op_context.input->dims; + TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims); + for (int idx = 0; idx < dims; ++idx) { + TF_LITE_ENSURE_MSG(context, + (op_context.params->before_padding[idx] >= 0 && + op_context.params->after_padding[idx] >= 0), + "Pad value has to be greater than equal to 0."); + output_size->data[idx] = + (input_size->data[idx] + op_context.params->before_padding[idx] + + op_context.params->after_padding[idx]); + } + + return context->ResizeTensor(context, op_context.output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + PadContext op_context(context, node); + + std::vector before_padding( + op_context.params->before_padding, + op_context.params->before_padding + op_context.params->num_dimensions); + std::vector after_padding( + op_context.params->after_padding, + op_context.params->after_padding + op_context.params->num_dimensions); + + // TODO(nupurgarg): Change TOCO's implementation to use padding arrays + // in forward order (depth, width, height, batch). + // Converts from int[] = {depth, width, height, batch} to int[] = {batch, + // height, width, depth} to match TOCO's implementation of pad in + // referenced_ops.h and optimized_ops.h. + std::reverse(before_padding.begin(), before_padding.end()); + std::reverse(after_padding.begin(), after_padding.end()); + +#define TF_LITE_PAD(type, scalar) \ + type::Pad(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), before_padding, after_padding, \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + + switch (op_context.input->type) { + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, float); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, uint8_t); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, int32_t); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, int64_t); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, "Type is currently not supported by Pad."); + return kTfLiteError; + } +#undef TF_LITE_PAD + return kTfLiteOk; +} + +} // namespace pad + +TfLiteRegistration* Register_PAD_REF() { + static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval}; + return &r; +} + +TfLiteRegistration* Register_PAD_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval}; + return &r; +} + +TfLiteRegistration* Register_PAD() { + return Register_PAD_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3ea9417df0e61dcff7a877726ab91c9b22691ba --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pad_test.cc @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class PadOpModel : public SingleOpModel { + public: + PadOpModel(std::initializer_list input_shape, + std::initializer_list before_padding, + std::initializer_list after_padding) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(builder_, builder_.CreateVector(before_padding), + builder_.CreateVector(after_padding)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(PadOpTest, TooManyDimensions) { + EXPECT_DEATH( + PadOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}), + "dims != 4"); +} + +// TODO(nupurgarg): Test case where before padding and after padding arrays +// don't contain the same number of dimensions. +TEST(PadOpTest, UnequalDimensions) { + EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {1, 2, 3}, {1, 2, 3}), + "dims != op_context.params->num_dimensions"); +} + +TEST(PadOpTest, InvalidPadValue) { + EXPECT_DEATH(PadOpModel({1, 1, 2, 1}, {0, 1, 2, 0}, {0, -1, -1, 0}), + "Pad value has to be greater than equal to 0."); +} + +TEST(PadOpTest, SimpleTest) { + PadOpModel m({1, 2, 2, 1}, {0, 1, 1, 0}, {0, 1, 1, 0}); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, + 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadOpTest, AdvancedTest) { + // The padding is input in the order of batch, height, width, depth. + PadOpModel m({1, 2, 3, 1}, {0, 0, 1, 0}, {0, 2, 3, 0}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/contrib/lite/kernels/pooling_test.cc index e1b51ec7d5141bf2a41e7ede3e90ff20ec523819..01c91b2ba905e249c36af19f175c68a7e7f17f6d 100644 --- a/tensorflow/contrib/lite/kernels/pooling_test.cc +++ b/tensorflow/contrib/lite/kernels/pooling_test.cc @@ -155,7 +155,7 @@ TEST(FloatPoolingOpTest, L2Pool) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index ca7a0dd1949a3a31d26be770a7df781cc5fe7533..45ad5f18903927ff8f2743e96c167cfcb11bdcca 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -20,7 +20,7 @@ namespace ops { namespace builtin { TfLiteRegistration* Register_RELU(); -TfLiteRegistration* Register_RELU1(); +TfLiteRegistration* Register_RELU_N1_TO_1(); TfLiteRegistration* Register_RELU6(); TfLiteRegistration* Register_TANH(); TfLiteRegistration* Register_LOGISTIC(); @@ -31,6 +31,7 @@ TfLiteRegistration* Register_CONV_2D(); TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); TfLiteRegistration* Register_SVDF(); TfLiteRegistration* Register_RNN(); +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN(); TfLiteRegistration* Register_EMBEDDING_LOOKUP(); TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE(); TfLiteRegistration* Register_FULLY_CONNECTED(); @@ -39,18 +40,27 @@ TfLiteRegistration* Register_HASHTABLE_LOOKUP(); TfLiteRegistration* Register_SOFTMAX(); TfLiteRegistration* Register_CONCATENATION(); TfLiteRegistration* Register_ADD(); +TfLiteRegistration* Register_SPACE_TO_BATCH_ND(); +TfLiteRegistration* Register_DIV(); +TfLiteRegistration* Register_SUB(); +TfLiteRegistration* Register_BATCH_TO_SPACE_ND(); TfLiteRegistration* Register_MUL(); TfLiteRegistration* Register_L2_NORMALIZATION(); TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_PAD(); TfLiteRegistration* Register_RESHAPE(); TfLiteRegistration* Register_RESIZE_BILINEAR(); TfLiteRegistration* Register_SKIP_GRAM(); TfLiteRegistration* Register_SPACE_TO_DEPTH(); +TfLiteRegistration* Register_GATHER(); +TfLiteRegistration* Register_TRANSPOSE(); +TfLiteRegistration* Register_MEAN(); +TfLiteRegistration* Register_SQUEEZE(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); - AddBuiltin(BuiltinOperator_RELU1, Register_RELU1()); + AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1()); AddBuiltin(BuiltinOperator_RELU6, Register_RELU6()); AddBuiltin(BuiltinOperator_TANH, Register_TANH()); AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC()); @@ -61,6 +71,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D()); AddBuiltin(BuiltinOperator_SVDF, Register_SVDF()); AddBuiltin(BuiltinOperator_RNN, Register_RNN()); + AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + Register_UNIDIRECTIONAL_SEQUENCE_RNN()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, Register_EMBEDDING_LOOKUP_SPARSE()); @@ -70,15 +82,24 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION()); AddBuiltin(BuiltinOperator_ADD, Register_ADD()); + AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND()); + AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND()); AddBuiltin(BuiltinOperator_MUL, Register_MUL()); AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_PAD, Register_PAD()); AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); + AddBuiltin(BuiltinOperator_GATHER, Register_GATHER()); + AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE()); + AddBuiltin(BuiltinOperator_MEAN, Register_MEAN()); + AddBuiltin(BuiltinOperator_DIV, Register_DIV()); + AddBuiltin(BuiltinOperator_SUB, Register_SUB()); + AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc index 59ce7d5648c04f78123b16a195d3a4928d28394b..0fbcf6e6aa311d2cac491336ee54ccf58bbda8fd 100644 --- a/tensorflow/contrib/lite/kernels/reshape_test.cc +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -83,8 +83,7 @@ TEST(ReshapeOpTest, WithStretchDimension) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 0257c0b557feb352413bcc33cb4e2ecdb32c5111..314a71e210d9b5ea75bb137ef228273ef48f28b5 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -111,7 +111,7 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc index e7f6bc904be5e4c23a88f5b4ae7e199346c78ab2..185b64cb44969b57588ea5d0b40f55b6ddf8e11f 100644 --- a/tensorflow/contrib/lite/kernels/skip_gram_test.cc +++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc @@ -251,7 +251,7 @@ TEST(SkipGramTest, TestInputWithExtraSpace) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc index ec8ec03b0d0279cad8543352b1dbaf34c88a7957..6c5338ff0fd26337c9adc8e0b94a0a88edfde37f 100644 --- a/tensorflow/contrib/lite/kernels/softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -136,8 +136,7 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e22d0db56a233bf554c57cf86275832ce941a18 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -0,0 +1,182 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace space_to_batch_nd { + +// This file has two implementations of SpaceToBatchND. +enum KernelType { + kReference, + kGenericOptimized, +}; + +// Inputs specified in the 2nd tensor (block_shape) and 3rd tensor (paddings) +// are ignored. Only use the `block_shape` and `paddings` specified in params. +// TODO(nupurgarg): Support inputs as tensors in SpaceToBatchND. +struct SpaceToBatchNDContext { + SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteSpaceToBatchNDParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +// Currently, only 4D NHWC input/output op_context are supported. +// The 4D array need to have exactly 2 spatial dimensions. +// TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND. +const int kInputDimensionNum = 4; +const int kOutputDimensionNum = 4; +const int kSpatialDimensionNum = 2; +const int kPaddingDimensionNum = 4; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, NumInputs(node) >= 1 && NumInputs(node) <= 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + SpaceToBatchNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.params->num_spatial_dimensions, + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + const TfLiteIntArray* input_size = op_context.input->dims; + const int* block_shape = op_context.params->block_shape; + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(kOutputDimensionNum); + + // Ensures the input height and width (with padding) is a multiple of block + // shape height and width. + for (int dim = 0; dim < kSpatialDimensionNum; ++dim) { + int final_dim_size = + (input_size->data[dim + 1] + op_context.params->before_paddings[dim] + + op_context.params->after_paddings[dim]); + TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0); + output_size->data[dim + 1] = final_dim_size / block_shape[dim]; + } + + const int output_batch_size = + input_size->data[0] * block_shape[0] * block_shape[1]; + const int output_channel_size = input_size->data[3]; + + output_size->data[0] = output_batch_size; + output_size->data[3] = output_channel_size; + + return context->ResizeTensor(context, op_context.output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + SpaceToBatchNDContext op_context(context, node); + + int block_shape_dims_array[1] = {kSpatialDimensionNum}; + Dims<4> block_shape_dims = GetTensorDims(block_shape_dims_array, 1); + + // Initialize padding array in the format accepted by the kernel code. + // TODO(nupurgarg): Make kernel code accept padding array format that is + // consistent with Pad operation (i.e. before_paddings and after_paddings). + TfLiteIntArray* padding_data = TfLiteIntArrayCreate(kPaddingDimensionNum); + padding_data->data[0] = op_context.params->before_paddings[0]; + padding_data->data[1] = op_context.params->after_paddings[0]; + padding_data->data[2] = op_context.params->before_paddings[1]; + padding_data->data[3] = op_context.params->after_paddings[1]; + int padding_dims_array[1] = {kPaddingDimensionNum}; + Dims<4> padding_dims = GetTensorDims(padding_dims_array, 1); + +#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ + type::SpaceToBatchND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + op_context.params->block_shape, block_shape_dims, \ + padding_data->data, padding_dims, \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + switch (op_context.input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float); + } else { + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t); + } else { + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t); + } else { + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t); + } else { + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported by SpaceToBatch."); + return kTfLiteError; + } +#undef TF_LITE_SPACE_TO_BATCH_ND + + TfLiteIntArrayFree(padding_data); + return kTfLiteOk; +} + +} // namespace space_to_batch_nd + +TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_batch_nd::Prepare, + space_to_batch_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_batch_nd::Prepare, + space_to_batch_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_BATCH_ND() { + // return Register_SPACE_TO_BATCH_ND_REF(); + return Register_SPACE_TO_BATCH_ND_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..45a6aef73d05b57a7f9a7fc6f58c3971c6e03118 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class SpaceToBatchNDOpModel : public SingleOpModel { + public: + SpaceToBatchNDOpModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list before_paddings, + std::initializer_list after_paddings) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOptions_SpaceToBatchNDOptions, + CreateSpaceToBatchNDOptions( + builder_, builder_.CreateVector(block_shape), + builder_.CreateVector(before_paddings), + builder_.CreateVector(after_paddings)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(SpaceToBatchNDOpTest, InvalidShapeTest) { + EXPECT_DEATH(SpaceToBatchNDOpModel({1, 3, 3, 1}, {2, 2}, {0, 0}, {0, 0}), + "Cannot allocate tensors"); +} + +TEST(SpaceToBatchNDOpTest, SimpleTest) { + SpaceToBatchNDOpModel m({1, 4, 4, 1}, {2, 2}, {0, 0}, {0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(SpaceToBatchNDOpTest, MultipleInputBatches) { + SpaceToBatchNDOpModel m({2, 2, 4, 1}, {2, 2}, {0, 0}, {0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7, + 13, 15, 6, 8, 14, 16})); +} + +TEST(SpaceToBatchNDOpTest, SimplePadding) { + SpaceToBatchNDOpModel m({1, 5, 2, 1}, {3, 2}, {1, 2}, {0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7, + 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10, + })); +} + +TEST(SpaceToBatchNDOpTest, ComplexPadding) { + SpaceToBatchNDOpModel m({1, 4, 2, 1}, {3, 2}, {1, 2}, {1, 4}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, + 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0, + 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc index 911f08a92ccd6a97bee414c87bd79091808f0ed1..997f354861a235fb511235e4d64544dc8c3ddb34 100644 --- a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc +++ b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc @@ -95,8 +95,7 @@ TEST(SpaceToDepthOpModel, Int64) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc new file mode 100644 index 0000000000000000000000000000000000000000..29447ab021c7b68ff51070d35262402e08dc7ab9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/squeeze.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace squeeze { + +struct SqueezeContext { + SqueezeContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteSqueezeParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + SqueezeContext op_context(context, node); + int input_num_dims = NumDimensions(op_context.input); + int num_squeeze_dims = op_context.params->num_squeeze_dims; + + // Determines number of dimensions of output tensor after squeeze. + const TfLiteIntArray* input_dims = op_context.input->dims; + const int* squeeze_dims = op_context.params->squeeze_dims; + TF_LITE_ENSURE(context, input_num_dims <= 8); + bool should_squeeze[8] = {false}; + int num_squeezed_dims = 0; + if (num_squeeze_dims == 0) { + for (int idx = 0; idx < input_num_dims; ++idx) { + if (input_dims->data[idx] == 1) { + should_squeeze[idx] = true; + ++num_squeezed_dims; + } + } + } else { + for (int idx = 0; idx < num_squeeze_dims; ++idx) { + int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + input_num_dims + : squeeze_dims[idx]; + TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims && + input_dims->data[current] == 1); + if (!should_squeeze[current]) ++num_squeezed_dims; + should_squeeze[current] = true; + } + } + // Sets output dimensions. + TfLiteIntArray* output_dims = + TfLiteIntArrayCreate(input_num_dims - num_squeezed_dims); + for (int in_idx = 0, out_idx = 0; in_idx < input_num_dims; ++in_idx) { + if (!should_squeeze[in_idx]) { + output_dims->data[out_idx++] = input_dims->data[in_idx]; + } + } + return context->ResizeTensor(context, op_context.output, output_dims); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + SqueezeContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes); + memcpy(op_context.output->data.raw, op_context.input->data.raw, + op_context.input->bytes); + return kTfLiteOk; +} + +} // namespace squeeze + +TfLiteRegistration* Register_SQUEEZE() { + static TfLiteRegistration r = {nullptr, nullptr, squeeze::Prepare, + squeeze::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/squeeze_test.cc b/tensorflow/contrib/lite/kernels/squeeze_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..409227b626afdc8cbed66a27e300b320b59023f2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/squeeze_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseSqueezeOpModel : public SingleOpModel { + public: + BaseSqueezeOpModel(const TensorData& input, const TensorData& output, + std::initializer_list axis) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp( + BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions, + CreateSqueezeOptions(builder_, builder_.CreateVector(axis)) + .Union()); + BuildInterpreter({GetShape(input_)}); + } + + int input() { return input_; } + + protected: + int input_; + int output_; +}; + +class FloatSqueezeOpModel : public BaseSqueezeOpModel { + public: + using BaseSqueezeOpModel::BaseSqueezeOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } +}; + +TEST(FloatSqueezeOpTest, SqueezeAll) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, + {TensorType_FLOAT32, {24}}, {}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +TEST(FloatSqueezeOpTest, SqueezeSelectedAxis) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, + {TensorType_FLOAT32, {24}}, {2}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +TEST(FloatSqueezeOpTest, SqueezeNegativeAxis) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, + {TensorType_FLOAT32, {24}}, {-1, 0}); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24})); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc new file mode 100644 index 0000000000000000000000000000000000000000..ddaf498d5bac0109429224e7cf66cb3debcabc22 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace sub { + +// This file has three implementation of Div. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalSubFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteSubParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_Sub(type) \ + type::Sub(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_Sub(reference_ops); + } else { + TF_LITE_Sub(optimized_ops); + } +#undef TF_LITE_Sub +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalSubFloat(context, node, params, input1, input2, output); + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace sub + +TfLiteRegistration* Register_SUB_REF() { + static TfLiteRegistration r = {nullptr, nullptr, sub::Prepare, + sub::Eval}; + return &r; +} + +TfLiteRegistration* Register_SUB_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, sub::Prepare, + sub::Eval}; + return &r; +} + +TfLiteRegistration* Register_SUB_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, sub::Prepare, + sub::Eval}; + return &r; +} + +TfLiteRegistration* Register_SUB() { +#ifdef USE_NEON + return Register_SUB_NEON_OPT(); +#else + return Register_SUB_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index d956025e9dfc9b6c03e55657023fb042c8ac485d..4de2ceaf053df31a4bc857fb250db416c071e80f 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -306,7 +306,7 @@ TEST(SVDFOpTest, BlackBoxTestRank2) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index f716ba8741fd469e7ee405ac300924b53c5c48e5..b69f2b3e4bc66c94fdfc7ed4c244151be63a1711 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -180,4 +180,17 @@ int32_t SingleOpModel::GetTensorSize(int index) const { return total_size; } +template <> +std::vector SingleOpModel::ExtractVector(int index) { + TfLiteTensor* tensor_ptr = interpreter_->tensor(index); + CHECK(tensor_ptr != nullptr); + const int num_strings = GetStringCount(tensor_ptr); + std::vector result; + result.reserve(num_strings); + for (int i = 0; i < num_strings; ++i) { + const auto str = GetString(tensor_ptr, i); + result.emplace_back(str.str, str.len); + } + return result; +} } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index e68e49466119c50ec123edb84f1b1b6390a15a60..531c1366a87e20e140e779b767e29b1fd1111f97 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -24,16 +24,11 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/testing/util.h" #include "tensorflow/core/platform/logging.h" namespace tflite { -inline void LogToStderr() { -#ifdef PLATFORM_GOOGLE - FLAGS_logtostderr = true; -#endif -} - // A gmock matcher that check that elements of a float vector match to a given // tolerance. std::vector<::testing::Matcher> ArrayFloatNear( @@ -197,6 +192,9 @@ class SingleOpModel { std::map> custom_registrations_; }; +// Strings have a special implementation that is in test_util.cc +template <> +std::vector SingleOpModel::ExtractVector(int index); } // namespace tflite #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc new file mode 100644 index 0000000000000000000000000000000000000000..75d8136b6a26efd805d9fc8e9db26dce2cfcfcb1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/transpose.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace transpose { + +// This file has two implementations of Transpose. +enum KernelType { + kReference, +}; + +// TODO(nupurgarg): Permutation arrays represented as a tensor are ignored. Only +// use the `perm` specified in `params`. +struct TransposeContext { + TransposeContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteTransposeParams* params; + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TransposeContext op_context(context, node); + int dims = NumDimensions(op_context.input); + + // Ensure validity of input tensor and permutation array. + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions); + TF_LITE_ENSURE_MSG(context, dims <= 4, + "Transpose op only supports 1D-4D input arrays."); + for (int idx = 0; idx < dims; ++idx) { + TF_LITE_ENSURE_MSG(context, + op_context.params->perm[idx] >= 0 && + op_context.params->perm[idx] < dims, + "Transpose op permutations array is out of bounds."); + } + + // Determine size of output tensor. + const TfLiteIntArray* input_size = op_context.input->dims; + TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims); + for (int idx = 0; idx < dims; ++idx) { + output_size->data[idx] = input_size->data[op_context.params->perm[idx]]; + } + + return context->ResizeTensor(context, op_context.output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TransposeContext op_context(context, node); + + // Reverse the permuted axes and convert to 4D due to the way Dims are + // constructed in GetTensorDims. + const int kOutputDimensionNum = 4; + int reversed_perm[kOutputDimensionNum]; + int size = op_context.params->num_dimensions; + for (int output_k = 0, input_k = size - 1; output_k < size; + ++output_k, --input_k) { + reversed_perm[output_k] = size - op_context.params->perm[input_k] - 1; + } + for (int k = size; k < kOutputDimensionNum; ++k) { + reversed_perm[k] = k; + } + +#define TF_LITE_TRANSPOSE(type, scalar) \ + type::Transpose(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output), reversed_perm) + + switch (op_context.input->type) { + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported by Transpose."); + return kTfLiteError; + } +#undef TF_LITE_TRANSPOSE + + return kTfLiteOk; +} + +} // namespace transpose + +TfLiteRegistration* Register_TRANSPOSE_REF() { + static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare, + transpose::Eval}; + return &r; +} + +TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/transpose_test.cc b/tensorflow/contrib/lite/kernels/transpose_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f5832cd5fa3d502b52bf5554111b45136b588ae --- /dev/null +++ b/tensorflow/contrib/lite/kernels/transpose_test.cc @@ -0,0 +1,247 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +void RunTestPermutation(const std::vector& shape, + const std::vector& perms, + std::vector* input_transposed) { + // Count elements and allocate output. + int count = 1; + for (auto factor : shape) count *= factor; + input_transposed->resize(count); + + // Create the dummy data + std::vector input(count); + for (int i = 0; i < input.size(); i++) { + input[i] = i; + } + + // Create reversed and padded perms. + int reversed_perms[4]; + for (int output_k = 0, input_k = shape.size() - 1; output_k < shape.size(); + output_k++, input_k--) { + reversed_perms[output_k] = shape.size() - perms[input_k] - 1; + } + // Unused dimensions should not be permuted so pad with identity transform + // subset. + for (int k = shape.size(); k < 4; k++) { + reversed_perms[k] = k; + } + + // Make input and output dims (i.e. reversed shape and dest_shape). + Dims<4> input_dims = GetTensorDims(shape); + Dims<4> output_dims; + for (int i = 0; i < 4; i++) { + output_dims.sizes[i] = input_dims.sizes[reversed_perms[i]]; + } + output_dims.strides[0] = 1; + for (int k = 1; k < 4; k++) { + output_dims.strides[k] = + output_dims.strides[k - 1] * output_dims.sizes[k - 1]; + } + + reference_ops::Transpose(input.data(), input_dims, + input_transposed->data(), output_dims, + reversed_perms); +} + +TEST(TransposeTest, TestRefOps1D) { + // Basic 1D identity. + std::vector out; + RunTestPermutation({3}, {0}, &out); + ASSERT_EQ(out, std::vector({0, 1, 2})); +} + +TEST(TransposeTest, TestRefOps2D) { + std::vector out; + // Basic 2D. + RunTestPermutation({3, 2}, {1, 0}, &out); + ASSERT_EQ(out, std::vector({0, 2, 4, 1, 3, 5})); + // Identity. + RunTestPermutation({3, 2}, {0, 1}, &out); + ASSERT_EQ(out, std::vector({0, 1, 2, 3, 4, 5})); +} + +TEST(TransposeTest, TestRefOps3D) { + std::vector out; + // Test 3 dimensional + { + std::vector ref({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, + 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}); + RunTestPermutation({2, 3, 4}, {2, 0, 1}, &out); + ASSERT_EQ(out, ref); + } + // Test 3 dimensional identity transform + { + RunTestPermutation({2, 3, 4}, {0, 1, 2}, &out); + std::vector ref(out.size()); + for (int k = 0; k < ref.size(); k++) ref[k] = k; + ASSERT_EQ(out, ref); + } +} + +TEST(TransposeTest, TestRefOps4D) { + std::vector out; + // Basic 4d. + RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out); + ASSERT_EQ( + out, + std::vector( + {0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44, + 60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104, + 5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49, + 65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109, + 10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54, + 70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114, + 15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59, + 75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119})); + RunTestPermutation({2, 3, 4, 5}, {0, 1, 2, 3}, &out); + // Basic identity. + std::vector ref(out.size()); + for (int k = 0; k < ref.size(); k++) ref[k] = k; + ASSERT_EQ(out, ref); +} + +class TransposeOpModel : public SingleOpModel { + public: + TransposeOpModel(std::initializer_list input_shape, + std::initializer_list perm) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions, + CreateTransposeOptions(builder_, builder_.CreateVector(perm)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(TransposeTest, TestUnequalPermSize) { + EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {2, 2}), + "dims != op_context.params->num_dimensions"); +} + +TEST(TransposeTest, TestPermOutOfBounds) { + EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, -1, -2, -3}), + "Transpose op permutations array is out of bounds."); + EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, 1, 2, 4}), + "Transpose op permutations array is out of bounds."); +} + +TEST(TransposeTest, Test1DInputTensor) { + TransposeOpModel m({3}, {0}); + m.SetInput({1, 2, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(TransposeTest, Test2DInputTensor) { + TransposeOpModel m({3, 2}, {1, 0}); + m.SetInput({0, 1, 2, 3, 4, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5})); +} + +TEST(TransposeTest, Test3DInputTensor) { + TransposeOpModel m({2, 3, 4}, {2, 0, 1}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, + 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23})); +} + +TEST(TransposeTest, Test5DInputTensor) { + EXPECT_DEATH(TransposeOpModel({1, 2, 3, 4, 5}, {0, 1, 2, 3, 4}), + "Transpose op only supports 1D-4D input arrays."); +} + +TEST(TransposeTest, SimpleTestNoReorder) { + TransposeOpModel m({1, 2, 3, 1}, {0, 1, 2, 3}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(TransposeTest, SimpleTestWithReorder) { + TransposeOpModel m({1, 2, 3, 1}, {2, 1, 3, 0}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2, 1, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6})); +} + +TEST(TransposeTest, ComplexTestWithReorder) { + TransposeOpModel m({2, 3, 4, 5}, {2, 0, 1, 3}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, + 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, + 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, + 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119}); + m.Invoke(); + + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3, 5})); + auto result = ElementsAreArray( + {0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44, + 60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104, + 5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49, + 65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109, + 10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54, + 70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114, + 15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59, + 75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119}); + EXPECT_THAT(m.GetOutput(), result); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5f1ec2cf3f45ae730b849b18e2b85fac50159c7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -0,0 +1,208 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace unidirectional_sequence_rnn { + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kRecurrentWeightsTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int kHiddenStateTensor = 0; +constexpr int kOutputTensor = 1; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + auto* params = reinterpret_cast(node->builtin_data); + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? input->dims->data[0] : input->dims->data[1]; + const int num_units = input_weights->dims->data[0]; + TF_LITE_ASSERT_EQ(input->dims->data[2], input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[kHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Resize state. + TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); + hidden_state_size_array->data[0] = batch_size; + hidden_state_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, + hidden_state_size_array)); + + // Mark hidden state as a persistent tensor. + hidden_state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3); + output_size_array->data[0] = (time_major) ? max_time : batch_size; + output_size_array->data[1] = (time_major) ? batch_size : max_time; + output_size_array->data[2] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, + output_size_array)); + + return kTfLiteOk; +} + +namespace { +void RnnStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int input_weights_stride, + int recurrent_weights_stride, TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch) { + // Output = bias + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = bias_ptr[o]; + } + + // Output += input * input_weights + for (int o = 0; o < num_units; o++) { + for (int i = 0; i < input_size; i++) { + output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; + } + input_weights_ptr += input_weights_stride; + } + + // Output += recurrent_weights * hidden_state + for (int o = 0; o < num_units; o++) { + for (int h = 0; h < num_units; h++) { + output_ptr_batch[o] += + hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; + } + recurrent_weights_ptr += recurrent_weights_stride; + } + + // Output = activation(Output) and update hidden_state + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = (ActivationFunctor(activation))(output_ptr_batch[o]); + hidden_state_ptr_batch[o] = output_ptr_batch[o]; + } +} +} // namespace + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[kHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Initialize the pointer bias. + const float* bias_ptr = bias->data.f; + + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? input->dims->data[0] : input->dims->data[1]; + const int num_units = input_weights->dims->data[0]; + const int input_size = input->dims->data[2]; + const int input_weights_stride = input_weights->dims->data[1]; + const int recurrent_weights_stride = recurrent_weights->dims->data[1]; + + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + if (time_major) { + // Unroll the sequence + for (int s = 0; s < max_time; s++) { + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + // Initialize the pointer to input and output. + const float* input_ptr_batch = + input->data.f + s * input_size * batch_size + b * input_size; + float* output_ptr_batch = + output->data.f + s * num_units * batch_size + b * num_units; + + RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, + bias_ptr, input_size, num_units, input_weights_stride, + recurrent_weights_stride, params->activation, + hidden_state_ptr_batch, output_ptr_batch); + } + } + } else { + // For each batch + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + for (int s = 0; s < max_time; s++) { + // Initialize the pointer to input and output. + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + output->data.f + b * num_units * max_time + s * num_units; + + RnnStep(input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, + bias_ptr, input_size, num_units, input_weights_stride, + recurrent_weights_stride, params->activation, + hidden_state_ptr_batch, output_ptr_batch); + } + } + } + return kTfLiteOk; +} + +} // namespace unidirectional_sequence_rnn + +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + unidirectional_sequence_rnn::Prepare, + unidirectional_sequence_rnn::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..82c680ec3d8656004d721c8498292677cb061b6b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -0,0 +1,352 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Sequential RNN op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0 +}; + +class UnidirectionalRNNOpModel : public SingleOpModel { + public: + UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size, + bool time_major) + : batches_(batches), + sequence_len_(sequence_len), + units_(units), + input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(TensorType_FLOAT32); + recurrent_weights_ = AddInput(TensorType_FLOAT32); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOptions_SequenceRNNOptions, + CreateSequenceRNNOptions(builder_, time_major, + ActivationFunctionType_RELU) + .Union()); + if (time_major) { + BuildInterpreter({{sequence_len_, batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } else { + BuildInterpreter({{batches_, sequence_len_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + int sequence_len() { return sequence_len_; } + + private: + int input_; + int weights_; + int recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int sequence_len_; + int units_; + int input_size_; +}; + +// TODO(mirkov): add another test which directly compares to TF once TOCO +// supports the conversion from dynamic_rnn with BasicRNNCell. +TEST(FullyConnectedOpTest, BlackBoxTest) { + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, /*time_major=*/false); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + float* batch_start = rnn_input; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(input_sequence_size, batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output; + float* golden_end = golden_start + rnn.num_units() * rnn.sequence_len(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) { + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, /*time_major=*/true); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + } + + rnn.Invoke(); + + std::vector expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_batch_start = rnn_golden_output + i * rnn.num_units(); + float* golden_batch_end = golden_batch_start + rnn.num_units(); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + } + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/lib_package/BUILD b/tensorflow/contrib/lite/lib_package/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3c1b8d3d45f2bb382bbe6b789ec6ac7ec89ebc66 --- /dev/null +++ b/tensorflow/contrib/lite/lib_package/BUILD @@ -0,0 +1,16 @@ +package(default_visibility = ["//visibility:private"]) + +# Create the LICENSE file for libraries that are used by TensorFlow Lite +# C library. +genrule( + name = "clicenses_generate", + srcs = [ + "//third_party/eigen3:LICENSE", + "@arm_neon_2_x86_sse//:LICENSE", + "@farmhash_archive//:COPYING", + "@gemmlowp//:LICENSE", + ], + outs = ["LICENSE"], + cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", + tools = [":concat_licenses.sh"], +) diff --git a/tensorflow/contrib/lite/lib_package/concat_licenses.sh b/tensorflow/contrib/lite/lib_package/concat_licenses.sh new file mode 100755 index 0000000000000000000000000000000000000000..2070f64e9fa4384234361556da0ed6f5089319b3 --- /dev/null +++ b/tensorflow/contrib/lite/lib_package/concat_licenses.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Script aimed to combining multiple license files into a single one. + +for f in $@ +do + echo "--------------------------------------------------------------------------------" + echo "BEGIN LICENSE FOR $f" + echo "--------------------------------------------------------------------------------" + cat $f + echo "--------------------------------------------------------------------------------" + echo "END LICENSE FOR $f" + echo "--------------------------------------------------------------------------------" +done diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h new file mode 100644 index 0000000000000000000000000000000000000000..b11d86c375ca6bd8693f2271df63ecb3c87657de --- /dev/null +++ b/tensorflow/contrib/lite/memory_planner.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// A MemoryPlanner is responsible for planning and executing a number of +// memory-related operations that are necessary in TF Lite. +class MemoryPlanner { + public: + virtual ~MemoryPlanner() {} + + // Plans the necessary memory allocations. This is the MemoryPlanner's + // pre-processing step and is called when the graph structure is known but + // actual size of the tensors is not. + virtual TfLiteStatus PlanAllocations() = 0; + + // Allocates the necessary memory to execute all nodes in the interval + // [first_node, last_node]. + virtual TfLiteStatus ExecuteAllocations(int first_node, int last_node) = 0; + + // Invalidates allocations made earliers. This is called when tensors sizes + // have change. All planned allocations remain, but can't be used until + // ExecuteAllocations() is called. + virtual TfLiteStatus ResetAllocations() = 0; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index e2f3560e61baae88a4afaafaa202cde784063efc..4b0c853f77c102efa7574ff97c254d92504730a3 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -60,6 +60,14 @@ std::unique_ptr FlatBufferModel::BuildFromBuffer( return model; } +std::unique_ptr FlatBufferModel::BuildFromModel( + const tflite::Model* model_spec, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(model_spec, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, ErrorReporter* error_reporter, bool use_nnapi) : error_reporter_(error_reporter ? error_reporter @@ -72,8 +80,7 @@ FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, } else { allocation_ = new FileCopyAllocation(filename, error_reporter); } - if (!allocation_->valid()) return; - if (!CheckModelIdentifier()) return; + if (!allocation_->valid() || !CheckModelIdentifier()) return; model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); } @@ -99,6 +106,13 @@ FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); } +FlatBufferModel::FlatBufferModel(const Model* model, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + model_ = model; +} + FlatBufferModel::~FlatBufferModel() { delete allocation_; } InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, @@ -160,6 +174,27 @@ std::vector FlatBufferIntArrayToVector(T* flat_array) { return ret; } +// Copies the contents from the flatbuffer int vector `flatbuffer` into the +// int array `buffer`. `flat_vector` and `buffer` represent the same +// configuration operation for a given operation. +void FlatBufferIntVectorToArray(int max_size_of_buffer, + const flatbuffers::Vector* flat_vector, + int* buffer, ErrorReporter* error_reporter) { + if (!flat_vector) { + error_reporter->Report("Input array not provided for operation.\n"); + } else { + int num_dimensions = flat_vector->Length(); + if (num_dimensions > max_size_of_buffer / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in the operation's input array.\n"); + } else { + for (int i = 0; i < num_dimensions; ++i) { + buffer[i] = flat_vector->Get(i); + } + } + } +} + // Allocate a structure using C malloc, but make sure the structure is a // POD structure that doesn't require constructors to run. The reason we do // this, is that Interpreter's C extension part will take ownership and wants @@ -175,6 +210,9 @@ T* MallocPOD() { // This handles builtin data explicitly as there are flatbuffer schemas. // // Returns memory that must be feed. +// +// TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program +// crashes if error reporter is called. void* ParseOpData(const Operator* op, BuiltinOperator op_type, ErrorReporter* error_reporter) { auto parse_padding = [](Padding padding) { @@ -192,7 +230,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, return kTfLiteActNone; case ActivationFunctionType_RELU: return kTfLiteActRelu; - case ActivationFunctionType_RELU1: + case ActivationFunctionType_RELU_N1_TO_1: return kTfLiteActRelu1; case ActivationFunctionType_RELU6: return kTfLiteActRelu6; @@ -248,7 +286,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_TANH: case BuiltinOperator_LOGISTIC: case BuiltinOperator_RELU: - case BuiltinOperator_RELU1: + case BuiltinOperator_RELU_N1_TO_1: case BuiltinOperator_RELU6: case BuiltinOperator_CONCAT_EMBEDDINGS: break; @@ -301,6 +339,17 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { + TfLiteSequenceRNNParams* params = MallocPOD(); + if (auto* sequence_rnn_params = + op->builtin_options_as_SequenceRNNOptions()) { + params->activation = + parse_activation(sequence_rnn_params->fused_activation_function()); + params->time_major = sequence_rnn_params->time_major(); + } + builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_RNN: { TfLiteRNNParams* params = MallocPOD(); if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { @@ -375,6 +424,24 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_DIV: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_DivOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SUB: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SubOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_L2_NORMALIZATION: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { @@ -417,23 +484,35 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_PAD: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_PadOptions()) { + auto* before_padding = schema_params->before_padding(); + FlatBufferIntVectorToArray(sizeof(params->before_padding), + before_padding, params->before_padding, + error_reporter); + + auto* after_padding = schema_params->after_padding(); + FlatBufferIntVectorToArray(sizeof(params->after_padding), after_padding, + params->after_padding, error_reporter); + + if (before_padding->Length() != after_padding->Length()) { + error_reporter->Report( + "Before padding and after padding arrays need to contain the " + "same number of dimensions.\n"); + } + params->num_dimensions = after_padding->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_RESHAPE: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { auto* new_shape = schema_params->new_shape(); - if (!new_shape) { - error_reporter->Report("No new_shape provided for Reshape\n"); - } else { - params->num_dimensions = new_shape->Length(); - if (params->num_dimensions > sizeof(params->shape) / sizeof(int)) { - error_reporter->Report( - "Found too many dimensions in Reshape's new_shape\n"); - } else { - for (int i = 0; i < params->num_dimensions; ++i) { - params->shape[i] = new_shape->Get(i); - } - } - } + FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, + params->shape, error_reporter); + params->num_dimensions = new_shape->Length(); } builtin_data = reinterpret_cast(params); break; @@ -456,6 +535,88 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_GATHER: { + TfLiteGatherParams* params = MallocPOD(); + params->axis = 0; + if (auto* gather_params = op->builtin_options_as_GatherOptions()) { + params->axis = gather_params->axis(); + } + + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SPACE_TO_BATCH_ND: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_SpaceToBatchNDOptions()) { + const auto& block_shape = schema_params->block_shape(); + FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape, + params->block_shape, error_reporter); + const auto& before_paddings = schema_params->before_paddings(); + FlatBufferIntVectorToArray(sizeof(params->before_paddings), + before_paddings, params->before_paddings, + error_reporter); + const auto& after_paddings = schema_params->after_paddings(); + FlatBufferIntVectorToArray(sizeof(params->after_paddings), + after_paddings, params->after_paddings, + error_reporter); + params->num_spatial_dimensions = block_shape->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_BATCH_TO_SPACE_ND: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_BatchToSpaceNDOptions()) { + const auto& block_shape = schema_params->block_shape(); + FlatBufferIntVectorToArray(sizeof(params->block_shape), block_shape, + params->block_shape, error_reporter); + const auto& before_crops = schema_params->before_crops(); + FlatBufferIntVectorToArray(sizeof(params->before_crops), before_crops, + params->before_crops, error_reporter); + const auto& after_crops = schema_params->after_crops(); + FlatBufferIntVectorToArray(sizeof(params->after_crops), after_crops, + params->after_crops, error_reporter); + params->num_spatial_dimensions = block_shape->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_TRANSPOSE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_TransposeOptions()) { + const auto& perm = schema_params->perm(); + FlatBufferIntVectorToArray(sizeof(params->perm), perm, params->perm, + error_reporter); + params->num_dimensions = perm->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_MEAN: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_MeanOptions()) { + const auto& axis = schema_params->axis(); + FlatBufferIntVectorToArray(sizeof(params->axis), axis, params->axis, + error_reporter); + params->keep_dims = schema_params->keep_dims(); + params->num_axis_dimensions = axis->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SQUEEZE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { + const auto& squeeze_dims = schema_params->squeeze_dims(); + FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, + params->squeeze_dims, error_reporter); + params->num_squeeze_dims = squeeze_dims->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } } return builtin_data; } diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 15659d33f37dfb2f119480ed88d2e1b81f34c145..e0c96f7f0480cd3146f95a22957477809cf0096d 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -45,18 +45,25 @@ namespace tflite { // or mmapped. This uses flatbuffers as the serialization format. class FlatBufferModel { public: - // Build a model based on a file. Return a nullptr in case of failure. + // Builds a model based on a file. Returns a nullptr in case of failure. static std::unique_ptr BuildFromFile( const char* filename, ErrorReporter* error_reporter = DefaultErrorReporter()); - // Build a model based on a pre-loaded flatbuffer. The caller retains + // Builds a model based on a pre-loaded flatbuffer. The caller retains // ownership of the buffer and should keep it alive until the returned object - // is destroyed. Return a nullptr in case of failure. + // is destroyed. Returns a nullptr in case of failure. static std::unique_ptr BuildFromBuffer( const char* buffer, size_t buffer_size, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Builds a model directly from a flatbuffer pointer. The caller retains + // ownership of the buffer and should keep it alive until the returned object + // is destroyed. Returns a nullptr in case of failure. + static std::unique_ptr BuildFromModel( + const tflite::Model* model_spec, + ErrorReporter* error_reporter = DefaultErrorReporter()); + // Releases memory or unmaps mmaped meory. ~FlatBufferModel(); @@ -75,7 +82,7 @@ class FlatBufferModel { bool CheckModelIdentifier() const; private: - // Load a model from `filename`. If `mmap_file` is true then use mmap, + // Loads a model from `filename`. If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. // // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be @@ -85,8 +92,8 @@ class FlatBufferModel { ErrorReporter* error_reporter = DefaultErrorReporter(), bool use_nnapi = false); - // Load a model from `ptr` and `num_bytes` of the model file. The `ptr` has to - // remain alive and unchanged until the end of this flatbuffermodel's + // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has + // to remain alive and unchanged until the end of this flatbuffermodel's // lifetime. // // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be @@ -94,6 +101,10 @@ class FlatBufferModel { FlatBufferModel(const char* ptr, size_t num_bytes, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Loads a model from Model flatbuffer. The `model` has to remain alive and + // unchanged until the end of this flatbuffermodel's lifetime. + FlatBufferModel(const Model* model, ErrorReporter* error_reporter); + // Flatbuffer traverser pointer. (Model* is a pointer that is within the // allocated memory of the data allocated by allocation's internals. const tflite::Model* model_ = nullptr; @@ -106,9 +117,9 @@ class FlatBufferModel { // model are mapped to executable function pointers (TfLiteRegistrations). class OpResolver { public: - // Find the op registration for a builtin operator by enum code. + // Finds the op registration for a builtin operator by enum code. virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; - // Find the op registration of a custom operator by op name. + // Finds the op registration of a custom operator by op name. virtual TfLiteRegistration* FindOp(const char* op) const = 0; virtual ~OpResolver() {} }; @@ -131,7 +142,7 @@ class InterpreterBuilder { public: InterpreterBuilder(const FlatBufferModel& model, const OpResolver& op_resolver); - // Build an interpreter given only the raw flatbuffer Model object (instead + // Builds an interpreter given only the raw flatbuffer Model object (instead // of a FlatBufferModel). Mostly used for testing. // If `error_reporter` is null, then DefaultErrorReporter() is used. InterpreterBuilder(const ::tflite::Model* model, diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index 61043866420752b552281e353be9a2b41a6aadc8..5330c8f594593655b2a8776cf6b399c0d16cdc19 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/testing/util.h" // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, // we must declare this in global namespace, so argument-dependent operator @@ -254,6 +255,28 @@ TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { ASSERT_FALSE(model); } +// Test that loading model directly from a Model flatbuffer works. +TEST(BasicFlatBufferModel, TestBuildFromModel) { + TestErrorReporter reporter; + FileCopyAllocation model_allocation( + "tensorflow/contrib/lite/testdata/test_model.bin", &reporter); + ASSERT_TRUE(model_allocation.valid()); + ::flatbuffers::Verifier verifier( + reinterpret_cast(model_allocation.base()), + model_allocation.bytes()); + ASSERT_TRUE(VerifyModelBuffer(verifier)); + const Model* model_fb = ::tflite::GetModel(model_allocation.base()); + + auto model = FlatBufferModel::BuildFromModel(model_fb); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + ASSERT_EQ( + InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); +} + // TODO(aselle): Add tests for serialization of builtin op data types. // These tests will occur with the evaluation tests of individual operators, // not here. @@ -261,7 +284,7 @@ TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD index fbdf19f2054cf01aec44e3fcb13d0d0a2ff6f914..733c3f4c7fa0605f24a1e6b4c458e34310c079c4 100644 --- a/tensorflow/contrib/lite/models/smartreply/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -1,7 +1,92 @@ package(default_visibility = ["//visibility:public"]) +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + licenses(["notice"]) # Apache 2.0 +gen_selected_ops( + name = "smartreply_ops", + model = "@tflite_smartreply//:smartreply.tflite", +) + +cc_library( + name = "custom_ops", + srcs = [ + "ops/extract_feature.cc", + "ops/normalize.cc", + "ops/predict.cc", + ":smartreply_ops", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + "@farmhash_archive//:farmhash", + ], +) + +cc_library( + name = "predictor_lib", + srcs = ["predictor.cc"], + hdrs = ["predictor.h"], + copts = tflite_copts(), + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "extract_feature_op_test", + size = "small", + srcs = ["ops/extract_feature_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@farmhash_archive//:farmhash", + ], +) + +cc_test( + name = "normalize_op_test", + size = "small", + srcs = ["ops/normalize_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "predict_op_test", + size = "small", + srcs = ["ops/predict_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..75ed9432c8fcdfd77a64d3c659e6336c977cdda2 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f8767b443a2aa64b666c3b6bfb7db30cc0be62ea --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -0,0 +1,65 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "tflite_copts", + "tflite_jni_binary", +) + +filegroup( + name = "assets", + srcs = [ + "@tflite_smartreply//:model_files", + ], +) + +android_binary( + name = "SmartReplyDemo", + srcs = glob(["java/**/*.java"]), + assets = [":assets"], + assets_dir = "", + custom_package = "com.example.android.smartreply", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + ":smartreply_runtime", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +cc_library( + name = "smartreply_runtime", + srcs = ["libsmartreply_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libsmartreply_jni.so", + deps = [ + ":smartreply_jni_lib", + ], +) + +cc_library( + name = "smartreply_jni_lib", + srcs = [ + "smartreply_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/models/smartreply:predictor_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3c882ffc43fde577801428151a43b592e8faaed1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(glob(["*"])) + +filegroup( + name = "assets_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0a5b46b5f8d5fd6a0297c8056bb2fb9b6ad9ada --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt @@ -0,0 +1,16 @@ +Ok +Yes +No +👍 +☺ +😟 +❤️ +Lol +Thanks +Got it +Done +Nice +I don't know +What? +Why? +What's up? diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..02fec9ae5e971ad756ae6c2b0149a6aacfa27cad --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.app.Activity; +import android.os.Bundle; +import android.os.Handler; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.EditText; +import android.widget.TextView; + +/** + * The main (and only) activity of this demo app. Displays a text box which updates as messages are + * received. + */ +public class MainActivity extends Activity { + private static final String TAG = "SmartReplyDemo"; + private SmartReplyClient client; + + private Button sendButton; + private TextView messageTextView; + private EditText messageInput; + + private Handler handler; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + Log.v(TAG, "onCreate"); + setContentView(R.layout.main_activity); + + client = new SmartReplyClient(getApplicationContext()); + handler = new Handler(); + + sendButton = (Button) findViewById(R.id.send_button); + sendButton.setOnClickListener( + (View v) -> { + send(messageInput.getText().toString()); + }); + + messageTextView = (TextView) findViewById(R.id.message_text); + messageInput = (EditText) findViewById(R.id.message_input); + } + + @Override + protected void onStart() { + super.onStart(); + Log.v(TAG, "onStart"); + handler.post( + () -> { + client.loadModel(); + }); + } + + @Override + protected void onStop() { + super.onStop(); + Log.v(TAG, "onStop"); + handler.post( + () -> { + client.unloadModel(); + }); + } + + private void send(final String message) { + handler.post( + () -> { + messageTextView.append("Input: " + message + "\n"); + + SmartReply[] ans = client.predict(new String[] {message}); + for (SmartReply reply : ans) { + appendMessage("Reply: " + reply.getText()); + } + appendMessage("------"); + }); + } + + private void appendMessage(final String message) { + handler.post( + () -> { + messageTextView.append(message + "\n"); + }); + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java new file mode 100644 index 0000000000000000000000000000000000000000..3357fd17c11f870d1b0998bb26ffa9abf149686b --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.support.annotation.Keep; + +/** + * SmartReply contains predicted message, and confidence. + * + *

NOTE: this class used by JNI, class name and constructor should not be obfuscated. + */ +@Keep +public class SmartReply { + + private final String text; + private final float score; + + @Keep + public SmartReply(String text, float score) { + this.text = text; + this.score = score; + } + + public String getText() { + return text; + } + + public float getScore() { + return score; + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java new file mode 100644 index 0000000000000000000000000000000000000000..d5b1ac0ffbc47283aa0c1bf68c0a85ad6228cdcc --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.support.annotation.Keep; +import android.support.annotation.WorkerThread; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.List; + +/** Interface to load TfLite model and provide predictions. */ +public class SmartReplyClient implements AutoCloseable { + private static final String TAG = "SmartReplyDemo"; + private static final String MODEL_PATH = "smartreply.tflite"; + private static final String BACKOFF_PATH = "backoff_response.txt"; + private static final String JNI_LIB = "smartreply_jni"; + + private final Context context; + private long storage; + private MappedByteBuffer model; + + private volatile boolean isLibraryLoaded; + + public SmartReplyClient(Context context) { + this.context = context; + } + + public boolean isLoaded() { + return storage != 0; + } + + @WorkerThread + public synchronized void loadModel() { + if (!isLibraryLoaded) { + System.loadLibrary(JNI_LIB); + isLibraryLoaded = true; + } + + try { + model = loadModelFile(); + String[] backoff = loadBackoffList(); + storage = loadJNI(model, backoff); + } catch (IOException e) { + Log.e(TAG, "Fail to load model", e); + return; + } + } + + @WorkerThread + public synchronized SmartReply[] predict(String[] input) { + if (storage != 0) { + return predictJNI(storage, input); + } else { + return new SmartReply[] {}; + } + } + + @WorkerThread + public synchronized void unloadModel() { + close(); + } + + @Override + public synchronized void close() { + if (storage != 0) { + unloadJNI(storage); + storage = 0; + } + } + + private MappedByteBuffer loadModelFile() throws IOException { + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + try { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } finally { + inputStream.close(); + } + } + + private String[] loadBackoffList() throws IOException { + List labelList = new ArrayList(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH))); + String line; + while ((line = reader.readLine()) != null) { + if (!line.isEmpty()) { + labelList.add(line); + } + } + reader.close(); + String[] ans = new String[labelList.size()]; + labelList.toArray(ans); + return ans; + } + + @Keep + private native long loadJNI(MappedByteBuffer buffer, String[] backoff); + + @Keep + private native SmartReply[] predictJNI(long storage, String[] text); + + @Keep + private native void unloadJNI(long storage); +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml new file mode 100644 index 0000000000000000000000000000000000000000..23b4cadc007a4457d33b8c8fecf9b1e7b7436320 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml @@ -0,0 +1,44 @@ + + + + + + + + + + +